RSA decrypt: don't write past buffer end on error

When the decrypted data is bigger than the buffer, the one extra bytes
was being written to.
This commit is contained in:
Sean Parkinson
2025-11-21 11:18:38 +10:00
parent 59f4fa5686
commit 23c5678797
2 changed files with 27 additions and 6 deletions

View File

@@ -789,15 +789,18 @@ int test_wc_RsaPublicEncryptDecrypt(void)
WC_DECLARE_VAR(in, byte, TEST_STRING_SZ, NULL);
WC_DECLARE_VAR(plain, byte, TEST_STRING_SZ, NULL);
WC_DECLARE_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
WC_DECLARE_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
WC_ALLOC_VAR(in, byte, TEST_STRING_SZ, NULL);
WC_ALLOC_VAR(plain, byte, TEST_STRING_SZ, NULL);
WC_ALLOC_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
WC_ALLOC_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
#ifdef WC_DECLARE_VAR_IS_HEAP_ALLOC
ExpectNotNull(in);
ExpectNotNull(plain);
ExpectNotNull(cipher);
ExpectNotNull(shortPlain);
#endif
ExpectNotNull(XMEMCPY(in, inStr, inLen));
@@ -824,6 +827,11 @@ int test_wc_RsaPublicEncryptDecrypt(void)
ExpectIntEQ(XMEMCMP(plain, inStr, plainLen), 0);
/* Pass bad args - tested in another testing function.*/
/* Test for when plain length is less than required. */
ExpectIntEQ(wc_RsaPrivateDecrypt(cipher, cipherLenResult, shortPlain,
TEST_STRING_SZ - 4, &key), RSA_BUFFER_E);
WC_FREE_VAR(shortPlain, NULL);
WC_FREE_VAR(in, NULL);
WC_FREE_VAR(plain, NULL);
WC_FREE_VAR(cipher, NULL);

View File

@@ -3636,15 +3636,28 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
if (rsa_type == RSA_PRIVATE_DECRYPT) {
word32 i = 0;
word32 j;
byte last = 0;
int start = (int)((size_t)pad - (size_t)key->data);
for (j = 0; j < key->dataLen; j++) {
signed char c;
out[i] = key->data[j];
c = (signed char)ctMaskGTE((int)j, start);
c &= (signed char)ctMaskLT((int)i, (int)outLen);
/* 0 - no add, -1 add */
i += (word32)((byte)(-c));
signed char incMask;
signed char maskData;
/* When j < start + outLen then out[i] = key->data[j]
* else out[i] = last
*/
maskData = (signed char)ctMaskLT((int)j,
start + (int)outLen);
out[i] = (byte)(key->data[j] & maskData ) |
(byte)(last & (~maskData));
last = out[i];
/* Increment i when j is in range:
* [start..(start + outLen - 1)]. */
incMask = (signed char)ctMaskGTE((int)j, start);
incMask &= (signed char)ctMaskLT((int)j,
start + (int)outLen - 1);
i += (word32)((byte)(-incMask));
}
}
else