diff --git a/wolfcrypt/src/chacha20_poly1305.c b/wolfcrypt/src/chacha20_poly1305.c index 4d93f3643..1657f188d 100644 --- a/wolfcrypt/src/chacha20_poly1305.c +++ b/wolfcrypt/src/chacha20_poly1305.c @@ -329,22 +329,14 @@ int wc_ChaCha20Poly1305_UpdateAad(ChaChaPoly_Aead* aead, static int wc_ChaCha20Poly1305_CalcAad(ChaChaPoly_Aead* aead) { + /* Pad the AAD to 16 bytes */ int ret = 0; - if (aead == NULL) { - return BAD_FUNC_ARG; - } - if (aead->state == CHACHA20_POLY1305_STATE_AAD) { - /* Pad the AAD to 16 bytes */ - byte padding[CHACHA20_POLY1305_MAC_PADDING_ALIGNMENT - 1]; - word32 paddingLen = -(int)aead->aadLen & - (CHACHA20_POLY1305_MAC_PADDING_ALIGNMENT - 1); - if (paddingLen > 0) { - XMEMSET(padding, 0, paddingLen); - ret = wc_Poly1305Update(&aead->poly, padding, paddingLen); - } - - /* advance state */ - aead->state = CHACHA20_POLY1305_STATE_DATA; + byte padding[CHACHA20_POLY1305_MAC_PADDING_ALIGNMENT - 1]; + word32 paddingLen = -(int)aead->aadLen & + (CHACHA20_POLY1305_MAC_PADDING_ALIGNMENT - 1); + if (paddingLen > 0) { + XMEMSET(padding, 0, paddingLen); + ret = wc_Poly1305Update(&aead->poly, padding, paddingLen); } return ret; } @@ -364,7 +356,12 @@ int wc_ChaCha20Poly1305_UpdateData(ChaChaPoly_Aead* aead, } /* calculate AAD */ - ret = wc_ChaCha20Poly1305_CalcAad(aead); + if (aead->state == CHACHA20_POLY1305_STATE_AAD) { + ret = wc_ChaCha20Poly1305_CalcAad(aead); + } + + /* advance state */ + aead->state = CHACHA20_POLY1305_STATE_DATA; if (ret == 0) { /* Perform ChaCha20 encrypt or decrypt inline and Poly1305 auth calc */ @@ -393,12 +390,14 @@ int wc_ChaCha20Poly1305_Final(ChaChaPoly_Aead* aead, if (aead == NULL || outAuthTag == NULL) { return BAD_FUNC_ARG; } + if (aead->state != CHACHA20_POLY1305_STATE_AAD && + aead->state != CHACHA20_POLY1305_STATE_DATA) { + return BAD_STATE_E; + } /* make sure AAD is calculated */ - ret = wc_ChaCha20Poly1305_CalcAad(aead); - - if (aead->state != CHACHA20_POLY1305_STATE_DATA) { - return BAD_STATE_E; + if (aead->state == CHACHA20_POLY1305_STATE_AAD) { + ret = wc_ChaCha20Poly1305_CalcAad(aead); } /* Pad the ciphertext to 16 bytes */