Combine AES key wrap/unwrap callbacks

This commit is contained in:
Josh Holtrop
2025-07-22 16:34:37 -04:00
parent 525f1cc39e
commit 27f0ef8789
4 changed files with 28 additions and 68 deletions

View File

@@ -11,11 +11,12 @@
\param[in] keySz Size of the key to use.
\param[in] in Specify the input data to wrap/unwrap.
\param[in] inSz Size of the input data.
\param[in] wrap 1 if the requested operation is a key wrap, 0 for unwrap.
\param[out] out Specify the output buffer.
\param[out] outSz Size of the output buffer.
*/
typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, byte* out, word32 outSz);
typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz);
/*!
\ingroup PKCS7
@@ -500,33 +501,16 @@ int wc_PKCS7_VerifySignedData_ex(PKCS7* pkcs7, const byte* hashBuf,
\ingroup PKCS7
\brief Set the callback function to be used to perform a custom AES key
wrap operation.
wrap/unwrap operation.
\retval 0 Callback function was set successfully
\retval BAD_FUNC_ARG Parameter pkcs7 is NULL
\param pkcs7 pointer to the PKCS7 structure
\param aesKeyWrapCb pointer to custom AES key wrap function
\sa wc_PKCS7_SetAESKeyUnwrapCb
\param aesKeyWrapCb pointer to custom AES key wrap/unwrap function
*/
int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb);
/*!
\ingroup PKCS7
\brief Set the callback function to be used to perform a custom AES key
unwrap operation.
\retval 0 Callback function was set successfully
\retval BAD_FUNC_ARG Parameter pkcs7 is NULL
\param pkcs7 pointer to the PKCS7 structure
\param aesKeyUnwrapCb pointer to custom AES key unwrap function
\sa wc_PKCS7_SetAESKeyWrapCb
*/
int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb);
int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrapUnwrap aesKeyWrapCb);
/*!
\ingroup PKCS7

View File

@@ -17999,25 +17999,16 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void)
static int wasAESKeyWrapCbCalled = 0;
static int wasAESKeyUnwrapCbCalled = 0;
static int testAESKeyWrapCb(const byte* key, word32 keySz,
const byte* in, word32 inSz, byte* out, word32 outSz)
static int testAESKeyWrapUnwrapCb(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz)
{
(void)key;
(void)keySz;
wasAESKeyWrapCbCalled = 1;
XMEMSET(out, 0xEE, outSz);
if (inSz <= outSz) {
XMEMCPY(out, in, inSz);
}
return inSz;
}
static int testAESKeyUnwrapCb(const byte* key, word32 keySz,
const byte* in, word32 inSz, byte* out, word32 outSz)
{
(void)key;
(void)keySz;
wasAESKeyUnwrapCbCalled = 1;
(void)wrap;
if (wrap)
wasAESKeyWrapCbCalled = 1;
else
wasAESKeyUnwrapCbCalled = 1;
XMEMSET(out, 0xEE, outSz);
if (inSz <= outSz) {
XMEMCPY(out, in, inSz);
@@ -18104,8 +18095,7 @@ static int test_wc_PKCS7_SetAESKeyWrapUnwrapCb(void)
}
/* Test custom AES key wrap/unwrap callback */
ExpectIntEQ(wc_PKCS7_SetAESKeyWrapCb(pkcs7, testAESKeyWrapCb), 0);
ExpectIntEQ(wc_PKCS7_SetAESKeyUnwrapCb(pkcs7, testAESKeyUnwrapCb), 0);
ExpectIntEQ(wc_PKCS7_SetAESKeyWrapUnwrapCb(pkcs7, testAESKeyWrapUnwrapCb), 0);
ExpectIntGE(wc_PKCS7_EncodeEnvelopedData(pkcs7, output,
(word32)sizeof(output)), 0);

View File

@@ -6838,9 +6838,9 @@ static int wc_PKCS7_KeyWrap(const wc_PKCS7 * pkcs7, byte const * cek,
#endif
if (direction == AES_ENCRYPTION) {
if (pkcs7->aesKeyWrapCb != NULL) {
ret = pkcs7->aesKeyWrapCb(kek, kekSz, cek, cekSz,
out, outSz);
if (pkcs7->aesKeyWrapUnwrapCb != NULL) {
ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 1,
out, outSz);
}
else {
ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz,
@@ -6848,9 +6848,9 @@ static int wc_PKCS7_KeyWrap(const wc_PKCS7 * pkcs7, byte const * cek,
}
} else if (direction == AES_DECRYPTION) {
if (pkcs7->aesKeyUnwrapCb != NULL) {
ret = pkcs7->aesKeyUnwrapCb(kek, kekSz, cek, cekSz,
out, outSz);
if (pkcs7->aesKeyWrapUnwrapCb != NULL) {
ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 0,
out, outSz);
}
else {
ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz,
@@ -11094,28 +11094,17 @@ int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK cb)
/* return 0 on success */
int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb)
int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb)
{
if (pkcs7 == NULL)
return BAD_FUNC_ARG;
pkcs7->aesKeyWrapCb = aesKeyWrapCb;
pkcs7->aesKeyWrapUnwrapCb = aesKeyWrapUnwrapCb;
return 0;
}
/* return 0 on success */
int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb)
{
if (pkcs7 == NULL)
return BAD_FUNC_ARG;
pkcs7->aesKeyUnwrapCb = aesKeyUnwrapCb;
return 0;
}
/* Decrypt ASN.1 OtherRecipientInfo (ori), as defined by:
*
* OtherRecipientInfo ::= SEQUENCE {

View File

@@ -213,8 +213,8 @@ typedef int (*CallbackWrapCEK)(wc_PKCS7* pkcs7, byte* cek, word32 cekSz,
byte* originKey, word32 originKeySz,
byte* out, word32 outSz,
int keyWrapAlgo, int type, int dir);
typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, byte* out, word32 outSz);
typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz);
/* Callbacks for supporting different stream cases */
typedef int (*CallbackGetContent)(wc_PKCS7* pkcs7, byte** content, void* ctx);
@@ -373,8 +373,7 @@ struct wc_PKCS7 {
} decryptKey;
#endif
CallbackAESKeyWrap aesKeyWrapCb;
CallbackAESKeyWrap aesKeyUnwrapCb;
CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb;
/* !! NEW DATA MEMBERS MUST BE ADDED AT END !! */
};
@@ -503,10 +502,8 @@ WOLFSSL_API int wc_PKCS7_AddRecipient_ORI(wc_PKCS7* pkcs7, CallbackOriEncrypt c
int options);
WOLFSSL_API int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7,
CallbackWrapCEK wrapCEKCb);
WOLFSSL_API int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrap aesKeyWrapCb);
WOLFSSL_API int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrap aesKeyUnwrapCb);
WOLFSSL_API int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb);
#if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA)
WOLFSSL_API int wc_PKCS7_SetRsaSignRawDigestCb(wc_PKCS7* pkcs7,