From 23c5678797007ceb6057b90c29dee481eb729e45 Mon Sep 17 00:00:00 2001 From: Sean Parkinson Date: Fri, 21 Nov 2025 11:18:38 +1000 Subject: [PATCH] RSA decrypt: don't write past buffer end on error When the decrypted data is bigger than the buffer, the one extra bytes was being written to. --- tests/api/test_rsa.c | 8 ++++++++ wolfcrypt/src/rsa.c | 25 +++++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tests/api/test_rsa.c b/tests/api/test_rsa.c index 40786e467..1a1236aff 100644 --- a/tests/api/test_rsa.c +++ b/tests/api/test_rsa.c @@ -789,15 +789,18 @@ int test_wc_RsaPublicEncryptDecrypt(void) WC_DECLARE_VAR(in, byte, TEST_STRING_SZ, NULL); WC_DECLARE_VAR(plain, byte, TEST_STRING_SZ, NULL); WC_DECLARE_VAR(cipher, byte, TEST_RSA_BYTES, NULL); + WC_DECLARE_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL); WC_ALLOC_VAR(in, byte, TEST_STRING_SZ, NULL); WC_ALLOC_VAR(plain, byte, TEST_STRING_SZ, NULL); WC_ALLOC_VAR(cipher, byte, TEST_RSA_BYTES, NULL); + WC_ALLOC_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL); #ifdef WC_DECLARE_VAR_IS_HEAP_ALLOC ExpectNotNull(in); ExpectNotNull(plain); ExpectNotNull(cipher); + ExpectNotNull(shortPlain); #endif ExpectNotNull(XMEMCPY(in, inStr, inLen)); @@ -824,6 +827,11 @@ int test_wc_RsaPublicEncryptDecrypt(void) ExpectIntEQ(XMEMCMP(plain, inStr, plainLen), 0); /* Pass bad args - tested in another testing function.*/ + /* Test for when plain length is less than required. */ + ExpectIntEQ(wc_RsaPrivateDecrypt(cipher, cipherLenResult, shortPlain, + TEST_STRING_SZ - 4, &key), RSA_BUFFER_E); + + WC_FREE_VAR(shortPlain, NULL); WC_FREE_VAR(in, NULL); WC_FREE_VAR(plain, NULL); WC_FREE_VAR(cipher, NULL); diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index c870dba58..5e7a98ccd 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -3636,15 +3636,28 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, if (rsa_type == RSA_PRIVATE_DECRYPT) { word32 i = 0; word32 j; + byte last = 0; int start = (int)((size_t)pad - (size_t)key->data); for (j = 0; j < key->dataLen; j++) { - signed char c; - out[i] = key->data[j]; - c = (signed char)ctMaskGTE((int)j, start); - c &= (signed char)ctMaskLT((int)i, (int)outLen); - /* 0 - no add, -1 add */ - i += (word32)((byte)(-c)); + signed char incMask; + signed char maskData; + + /* When j < start + outLen then out[i] = key->data[j] + * else out[i] = last + */ + maskData = (signed char)ctMaskLT((int)j, + start + (int)outLen); + out[i] = (byte)(key->data[j] & maskData ) | + (byte)(last & (~maskData)); + last = out[i]; + + /* Increment i when j is in range: + * [start..(start + outLen - 1)]. */ + incMask = (signed char)ctMaskGTE((int)j, start); + incMask &= (signed char)ctMaskLT((int)j, + start + (int)outLen - 1); + i += (word32)((byte)(-incMask)); } } else