diff --git a/src/ssl.c b/src/ssl.c index 295eadc3b..4913f3061 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -54552,75 +54552,72 @@ int wolfSSL_mask_bits(WOLFSSL_BIGNUM* bn, int n) /* WOLFSSL_SUCCESS on ok */ int wolfSSL_BN_rand(WOLFSSL_BIGNUM* bn, int bits, int top, int bottom) { - int ret = WOLFSSL_FAILURE; - int len; - int initTmpRng = 0; - WC_RNG* rng = NULL; -#ifdef WOLFSSL_SMALL_STACK - WC_RNG* tmpRNG = NULL; - byte* buff = NULL; -#else - WC_RNG tmpRNG[1]; - byte buff[1024]; -#endif + int ret = WOLFSSL_SUCCESS; + int len = (bits + 7) / 8; + WC_RNG* rng = &globalRNG; + byte* buff = NULL; - (void)top; - (void)bottom; - WOLFSSL_MSG("wolfSSL_BN_rand"); + WOLFSSL_ENTER("wolfSSL_BN_rand"); - if (bits <= 0) { - return WOLFSSL_FAILURE; + if ((bn == NULL || bn->internal == NULL) || bits < 0 || + (bits == 0 && (bottom != 0 || top != -1)) || (bits == 1 && top > 0)) { + WOLFSSL_MSG("Bad argument"); + ret = WOLFSSL_FAILURE; } - len = bits / 8; - if (bits % 8) - len++; - - /* has to be a length of at least 1 since we set buf[0] and buf[len-1] */ - if (len < 1) { - return WOLFSSL_FAILURE; - } - -#ifdef WOLFSSL_SMALL_STACK - buff = (byte*)XMALLOC(1024, NULL, DYNAMIC_TYPE_TMP_BUFFER); - tmpRNG = (WC_RNG*) XMALLOC(sizeof(WC_RNG), NULL, DYNAMIC_TYPE_RNG); - if (buff == NULL || tmpRNG == NULL) { - XFREE(buff, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(tmpRNG, NULL, DYNAMIC_TYPE_RNG); - return ret; - } -#endif - - if (bn == NULL || bn->internal == NULL) - WOLFSSL_MSG("Bad function arguments"); - else if (wc_InitRng(tmpRNG) == 0) { - rng = tmpRNG; - initTmpRng = 1; - } - else if (initGlobalRNG) - rng = &globalRNG; - - if (rng) { - if (wc_RNG_GenerateBlock(rng, buff, len) != 0) - WOLFSSL_MSG("Bad wc_RNG_GenerateBlock"); - else { - buff[0] |= 0x80 | 0x40; - buff[len-1] |= 0x01; - - if (mp_read_unsigned_bin((mp_int*)bn->internal,buff,len) != MP_OKAY) - WOLFSSL_MSG("mp read bin failed"); - else - ret = WOLFSSL_SUCCESS; + if (ret == WOLFSSL_SUCCESS) { + buff = (byte*)XMALLOC(len, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (buff == NULL) { + WOLFSSL_MSG("Failed to allocate buffer."); + XFREE(buff, NULL, DYNAMIC_TYPE_TMP_BUFFER); + ret = WOLFSSL_FAILURE; } } - if (initTmpRng) - wc_FreeRng(tmpRNG); + if (ret == WOLFSSL_SUCCESS && initGlobalRNG == 0 && + wolfSSL_RAND_Init() != WOLFSSL_SUCCESS) { + WOLFSSL_MSG("Failed to use global RNG."); + ret = WOLFSSL_FAILURE; + } -#ifdef WOLFSSL_SMALL_STACK - XFREE(buff, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(tmpRNG, NULL, DYNAMIC_TYPE_RNG); -#endif + if (ret == WOLFSSL_SUCCESS && wc_RNG_GenerateBlock(rng, buff, len) != 0) { + WOLFSSL_MSG("wc_RNG_GenerateBlock failed"); + ret = WOLFSSL_FAILURE; + } + if (ret == WOLFSSL_SUCCESS && + mp_read_unsigned_bin((mp_int*)bn->internal,buff,len) != MP_OKAY) { + WOLFSSL_MSG("mp_read_unsigned_bin failed"); + ret = WOLFSSL_FAILURE; + } + if (ret == WOLFSSL_SUCCESS) { + /* Truncate to requested bit length. */ + mp_rshb((mp_int*)bn->internal, 8 - (bits % 8)); + + if (top == 0) { + if (mp_set_bit((mp_int*)bn->internal, bits - 1) != MP_OKAY) { + WOLFSSL_MSG("Failed to set top bit"); + ret = WOLFSSL_FAILURE; + } + } + else if (top > 0) { + if (mp_set_bit((mp_int*)bn->internal, bits - 1) != MP_OKAY || + mp_set_bit((mp_int*)bn->internal, bits - 2) != MP_OKAY) { + WOLFSSL_MSG("Failed to set top 2 bits"); + ret = WOLFSSL_FAILURE; + } + } + } + if (ret == WOLFSSL_SUCCESS && bottom && + mp_set_bit((mp_int*)bn->internal, 0) != MP_OKAY) { + WOLFSSL_MSG("Failed to set 0th bit"); + ret = WOLFSSL_FAILURE; + } + + if (buff != NULL) { + XFREE(buff, NULL, DYNAMIC_TYPE_TMP_BUFFER); + } + + WOLFSSL_LEAVE("wolfSSL_BN_rand", ret); return ret; } diff --git a/tests/api.c b/tests/api.c index 68b8dd6a4..8484d8f1c 100644 --- a/tests/api.c +++ b/tests/api.c @@ -36938,25 +36938,49 @@ static void test_wolfSSL_RAND_bytes(void) static void test_wolfSSL_BN_rand(void) { - #if defined(OPENSSL_EXTRA) +#if defined(OPENSSL_EXTRA) BIGNUM* bn; BIGNUM* range; printf(testingFmt, "wolfSSL_BN_rand()"); + /* Error conditions. */ + /* NULL BN. */ + AssertIntEQ(BN_rand(NULL, 0, 0, 0), SSL_FAILURE); AssertNotNull(bn = BN_new()); - AssertIntNE(BN_rand(bn, 0, 0, 0), SSL_SUCCESS); - BN_free(bn); + /* Negative bits. */ + AssertIntEQ(BN_rand(bn, -2, 0, 0), SSL_FAILURE); + /* 0 bits and top is not -1. */ + AssertIntEQ(BN_rand(bn, 0, 1, 0), SSL_FAILURE); + /* 0 bits and bottom is not 0. */ + AssertIntEQ(BN_rand(bn, 0, 0, 1), SSL_FAILURE); + /* 1 bit and top is 1. */ + AssertIntEQ(BN_rand(bn, 1, 1, 0), SSL_FAILURE); + + AssertIntEQ(BN_rand(bn, 0, -1, 0), SSL_SUCCESS); + AssertIntEQ(BN_num_bits(bn), 0); - AssertNotNull(bn = BN_new()); AssertIntEQ(BN_rand(bn, 8, 0, 0), SSL_SUCCESS); - BN_free(bn); + AssertIntEQ(BN_num_bits(bn), 8); + /* When top is 0, top bit should be 1. */ + AssertIntEQ(BN_is_bit_set(bn, 7), SSL_SUCCESS); - AssertNotNull(bn = BN_new()); - AssertIntEQ(BN_rand(bn, 64, 0, 0), SSL_SUCCESS); - BN_free(bn); + AssertIntEQ(BN_rand(bn, 8, 1, 0), SSL_SUCCESS); + /* When top is 1, top 2 bits should be 1. */ + AssertIntEQ(BN_is_bit_set(bn, 7), SSL_SUCCESS); + AssertIntEQ(BN_is_bit_set(bn, 6), SSL_SUCCESS); + + AssertIntEQ(BN_rand(bn, 8, 0, 1), SSL_SUCCESS); + /* When bottom is 1, bottom bit should be 1. */ + AssertIntEQ(BN_is_bit_set(bn, 0), SSL_SUCCESS); + + /* Regression test: Older versions of wolfSSL_BN_rand would round the + * requested number of bits up to the nearest multiple of 8. E.g. in this + * case, requesting a 13-bit random number would actually return a 16-bit + * random number. */ + AssertIntEQ(BN_rand(bn, 13, 0, 0), SSL_SUCCESS); + AssertIntEQ(BN_num_bits(bn), 13); - AssertNotNull(bn = BN_new()); AssertNotNull(range = BN_new()); AssertIntEQ(BN_rand(range, 64, 0, 0), SSL_SUCCESS); AssertIntEQ(BN_rand_range(bn, range), SSL_SUCCESS); @@ -36964,7 +36988,7 @@ static void test_wolfSSL_BN_rand(void) BN_free(range); printf(resultFmt, passed); - #endif +#endif } static void test_wolfSSL_pseudo_rand(void)