diff --git a/wolfcrypt/src/aes.c b/wolfcrypt/src/aes.c index 0875dab25..7d3128b4c 100644 --- a/wolfcrypt/src/aes.c +++ b/wolfcrypt/src/aes.c @@ -2532,7 +2532,7 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) /* Software AES - SetKey */ static int wc_AesSetKeyLocal(Aes* aes, const byte* userKey, word32 keylen, - const byte* iv, int dir) + const byte* iv, int dir, int checkKeyLen) { int ret; word32 *rk; @@ -2545,16 +2545,9 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) word32 localSz = 32; #endif - #if defined(WOLF_CRYPTO_CB) || (defined(WOLFSSL_DEVCRYPTO) && \ - (defined(WOLFSSL_DEVCRYPTO_AES) || defined(WOLFSSL_DEVCRYPTO_CBC))) || \ - (defined(WOLFSSL_ASYNC_CRYPT) && defined(WC_ASYNC_ENABLE_AES)) - #ifdef WOLF_CRYPTO_CB - if (aes->devId != INVALID_DEVID) - #endif - { - XMEMCPY(aes->devKey, userKey, keylen); + if (aes == NULL) { + return BAD_FUNC_ARG; } - #endif #ifdef WOLFSSL_IMX6_CAAM_BLOB if (keylen == (16 + WC_CAAM_BLOB_SZ) || @@ -2570,6 +2563,32 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) } #endif + #if defined(WOLF_CRYPTO_CB) || (defined(WOLFSSL_DEVCRYPTO) && \ + (defined(WOLFSSL_DEVCRYPTO_AES) || defined(WOLFSSL_DEVCRYPTO_CBC))) || \ + (defined(WOLFSSL_ASYNC_CRYPT) && defined(WC_ASYNC_ENABLE_AES)) + #ifdef WOLF_CRYPTO_CB + if (aes->devId != INVALID_DEVID) + #endif + { + if (keylen > sizeof(aes->devKey)) { + return BAD_FUNC_ARG; + } + XMEMCPY(aes->devKey, userKey, keylen); + } + #endif + + if (checkKeyLen) { + if (keylen != 16 && keylen != 24 && keylen != 32) { + return BAD_FUNC_ARG; + } + #ifdef AES_MAX_KEY_SIZE + /* Check key length */ + if (keylen > (AES_MAX_KEY_SIZE / 8)) { + return BAD_FUNC_ARG; + } + #endif + } + #if defined(WOLFSSL_AES_CFB) || defined(WOLFSSL_AES_COUNTER) || \ defined(WOLFSSL_AES_OFB) aes->left = 0; @@ -2599,6 +2618,9 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) } #endif /* WOLFSSL_AESNI */ + if (keylen > sizeof(aes->key)) { + return BAD_FUNC_ARG; + } rk = aes->key; XMEMCPY(rk, userKey, keylen); #if defined(LITTLE_ENDIAN_ORDER) && !defined(WOLFSSL_PIC32MZ_CRYPT) && \ @@ -2788,19 +2810,7 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) int wc_AesSetKey(Aes* aes, const byte* userKey, word32 keylen, const byte* iv, int dir) { - if (aes == NULL || - !((keylen == 16) || (keylen == 24) || (keylen == 32))) { - return BAD_FUNC_ARG; - } - - #if defined(AES_MAX_KEY_SIZE) - /* Check key length */ - if (keylen > (AES_MAX_KEY_SIZE / 8)) { - return BAD_FUNC_ARG; - } - #endif - - return wc_AesSetKeyLocal(aes, userKey, keylen, iv, dir); + return wc_AesSetKeyLocal(aes, userKey, keylen, iv, dir, 1); } #if defined(WOLFSSL_AES_DIRECT) || defined(WOLFSSL_AES_COUNTER) @@ -2809,10 +2819,7 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) int wc_AesSetKeyDirect(Aes* aes, const byte* userKey, word32 keylen, const byte* iv, int dir) { - if (aes == NULL) { - return BAD_FUNC_ARG; - } - return wc_AesSetKeyLocal(aes, userKey, keylen, iv, dir); + return wc_AesSetKeyLocal(aes, userKey, keylen, iv, dir, 0); } #endif /* WOLFSSL_AES_DIRECT || WOLFSSL_AES_COUNTER */ #endif /* wc_AesSetKey block */