diff --git a/wolfcrypt/src/aes.c b/wolfcrypt/src/aes.c index 150a84ab4..27f0ab52f 100755 --- a/wolfcrypt/src/aes.c +++ b/wolfcrypt/src/aes.c @@ -463,6 +463,18 @@ } CRYP_KeyInit(&AES_CRYP_KeyInitStructure); + /* set direction, key, and datatype */ + AES_CRYP_InitStructure.CRYP_AlgoDir = CRYP_AlgoDir_Decrypt; + AES_CRYP_InitStructure.CRYP_AlgoMode = CRYP_AlgoMode_AES_Key; + AES_CRYP_InitStructure.CRYP_DataType = CRYP_DataType_8b; + CRYP_Init(&AES_CRYP_InitStructure); + + /* enable crypto processor */ + CRYP_Cmd(ENABLE); + + /* wait until decrypt key has been intialized */ + while (CRYP_GetFlagStatus(CRYP_FLAG_BUSY) != RESET) {} + /* set direction, mode, and datatype */ AES_CRYP_InitStructure.CRYP_AlgoDir = CRYP_AlgoDir_Decrypt; AES_CRYP_InitStructure.CRYP_AlgoMode = CRYP_AlgoMode_AES_ECB; @@ -1800,6 +1812,9 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) #ifndef WOLFSSL_STM32_CUBEMX ByteReverseWords(rk, rk, keylen); #endif + #ifdef WOLFSSL_AES_COUNTER + aes->left = 0; + #endif return wc_AesSetIV(aes, iv); } @@ -1881,10 +1896,9 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) aes->rounds = keylen/4 + 6; XMEMCPY(aes->key, userKey, keylen); - - #ifdef WOLFSSL_AES_COUNTER - aes->left = 0; - #endif /* WOLFSSL_AES_COUNTER */ + #ifdef WOLFSSL_AES_COUNTER + aes->left = 0; + #endif return wc_AesSetIV(aes, iv); } @@ -1909,10 +1923,9 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) if (rk == NULL) return BAD_FUNC_ARG; - #ifdef WOLFSSL_AES_COUNTER - aes->left = 0; - #endif /* WOLFSSL_AES_COUNTER */ - + #ifdef WOLFSSL_AES_COUNTER + aes->left = 0; + #endif aes->keylen = keylen; aes->rounds = keylen/4 + 6; @@ -3000,17 +3013,41 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) #endif /* AES-CTR */ -#if defined(WOLFSSL_AES_COUNTER) || (defined(HAVE_AESGCM_DECRYPT) && defined(STM32_CRYPTO)) +#if defined(WOLFSSL_AES_COUNTER) + + #ifndef FREESCALE_LTC /* LTC doesn't need soft counter */ + /* Increment AES counter */ + static INLINE void IncrementAesCounter(byte* inOutCtr) + { + /* in network byte order so start at end and work back */ + int i; + for (i = AES_BLOCK_SIZE - 1; i >= 0; i--) { + if (++inOutCtr[i]) /* we're done unless we overflow */ + return; + } + } + #endif + #ifdef STM32_CRYPTO #ifdef WOLFSSL_STM32_CUBEMX int wc_AesCtrEncrypt(Aes* aes, byte* out, const byte* in, word32 sz) { + int ret = 0; CRYP_HandleTypeDef hcryp; + byte* tmp; if (aes == NULL || out == NULL || in == NULL) { return BAD_FUNC_ARG; } + /* consume any unused bytes left in aes->tmp */ + tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; + while (aes->left && sz) { + *(out++) = *(in++) ^ *(tmp++); + aes->left--; + sz--; + } + XMEMSET(&hcryp, 0, sizeof(CRYP_HandleTypeDef)); switch (aes->rounds) { case 10: /* 128-bit key */ @@ -3035,24 +3072,34 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) if (HAL_CRYP_AESCTR_Encrypt(&hcryp, (byte*)in, sz, out, STM32_HAL_TIMEOUT) != HAL_OK) { /* failed */ + ret = WC_TIMEOUT_E; } + IncrementAesCounter((byte*)aes->reg); HAL_CRYP_DeInit(&hcryp); - return 0; + return ret; } #else int wc_AesCtrEncrypt(Aes* aes, byte* out, const byte* in, word32 sz) { word32 *enc_key, *iv; - int len = (int)sz; + byte* tmp; + CRYP_InitTypeDef AES_CRYP_InitStructure; + CRYP_KeyInitTypeDef AES_CRYP_KeyInitStructure; + CRYP_IVInitTypeDef AES_CRYP_IVInitStructure; if (aes == NULL || out == NULL || in == NULL) { return BAD_FUNC_ARG; } - CRYP_InitTypeDef AES_CRYP_InitStructure; - CRYP_KeyInitTypeDef AES_CRYP_KeyInitStructure; - CRYP_IVInitTypeDef AES_CRYP_IVInitStructure; + + /* consume any unused bytes left in aes->tmp */ + tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; + while (aes->left && sz) { + *(out++) = *(in++) ^ *(tmp++); + aes->left--; + sz--; + } enc_key = aes->key; iv = aes->reg; @@ -3119,7 +3166,7 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) /* enable crypto processor */ CRYP_Cmd(ENABLE); - while (len > 0) { + while (sz >= AES_BLOCK_SIZE) { /* flush IN/OUT FIFOs */ CRYP_FIFOFlush(); @@ -3136,27 +3183,37 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) *(uint32_t*)&out[8] = CRYP_DataOut(); *(uint32_t*)&out[12] = CRYP_DataOut(); - /* store iv for next call */ - XMEMCPY(aes->reg, out + len - AES_BLOCK_SIZE, AES_BLOCK_SIZE); + IncrementAesCounter((byte*)aes->reg); - len -= AES_BLOCK_SIZE; + sz -= AES_BLOCK_SIZE; in += AES_BLOCK_SIZE; out += AES_BLOCK_SIZE; + aes->left = 0; } /* disable crypto processor */ CRYP_Cmd(DISABLE); + + /* handle non block size remaining and store unused byte count in left */ + if (sz) { + wc_AesEncrypt(aes, (byte*)aes->reg, (byte*)aes->tmp); + IncrementAesCounter((byte*)aes->reg); + + aes->left = AES_BLOCK_SIZE; + tmp = (byte*)aes->tmp; + + while (sz--) { + *(out++) = *(in++) ^ *(tmp++); + aes->left--; + } + } + + return 0; } #endif /* WOLFSSL_STM32_CUBEMX */ #elif defined(WOLFSSL_PIC32MZ_CRYPT) - static void Pic32AesIncIV(Aes* aes) { - int i; - for (i = AES_BLOCK_SIZE - 1; i >= 0; i--) { - if (++((byte *)aes->iv_ce)[i]) - break; - } - } + int wc_AesCtrEncrypt(Aes* aes, byte* out, const byte* in, word32 sz) { int ret = 0; @@ -3184,7 +3241,7 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) XMEMCPY(out, out_block + aes->left, odd); aes->left = 0; XMEMSET(tmp, 0x0, AES_BLOCK_SIZE); - Pic32AesIncIV(aes); + IncrementAesCounter((byte*)aes->reg); } in += odd; out+= odd; @@ -3201,7 +3258,7 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) out += even; in += even; do { - Pic32AesIncIV(aes); + IncrementAesCounter((byte*)aes->reg); even -= AES_BLOCK_SIZE; } while (even > 0); } @@ -3235,9 +3292,9 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) if (aes == NULL || out == NULL || in == NULL) { return BAD_FUNC_ARG; } - tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; /* consume any unused bytes left in aes->tmp */ + tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; while (aes->left && sz) { *(out++) = *(in++) ^ *(tmp++); aes->left--; @@ -3252,24 +3309,13 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) LTC_AES_CryptCtr(LTC_BASE, in, out, sz, iv, enc_key, keySize, (byte*)aes->tmp, - (uint32_t*)&(aes->left)); + (uint32_t*)&aes->left); } return 0; } #else - /* Increment AES counter */ - static INLINE void IncrementAesCounter(byte* inOutCtr) - { - int i; - - /* in network byte order so start at end and work back */ - for (i = AES_BLOCK_SIZE - 1; i >= 0; i--) { - if (++inOutCtr[i]) /* we're done unless we overflow */ - return; - } - } int wc_AesCtrEncrypt(Aes* aes, byte* out, const byte* in, word32 sz) { @@ -3278,9 +3324,9 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz) if (aes == NULL || out == NULL || in == NULL) { return BAD_FUNC_ARG; } - tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; /* consume any unused bytes left in aes->tmp */ + tmp = (byte*)aes->tmp + AES_BLOCK_SIZE - aes->left; while (aes->left && sz) { *(out++) = *(in++) ^ *(tmp++); aes->left--; @@ -6969,7 +7015,7 @@ int wc_AesGcmEncrypt(Aes* aes, byte* out, const byte* in, word32 sz, #else -#if defined(WOLFSSL_STM32F4) || defined(WOLFSSL_STM32F7) +#if defined(STM32_CRYPTO) && (defined(WOLFSSL_STM32F4) || defined(WOLFSSL_STM32F7)) /* additional argument checks - STM32 HW only supports 12 byte IV */ if (ivSz != NONCE_SZ) {