diff --git a/src/internal.c b/src/internal.c index 8b7807152..787d10b7f 100644 --- a/src/internal.c +++ b/src/internal.c @@ -19108,11 +19108,12 @@ int SendClientKeyExchange(WOLFSSL* ssl) goto exit_scke; } - ret = wc_DhSetKey(ssl->buffers.serverDH_Key, + ret = wc_DhSetCheckKey(ssl->buffers.serverDH_Key, ssl->buffers.serverDH_P.buffer, ssl->buffers.serverDH_P.length, ssl->buffers.serverDH_G.buffer, - ssl->buffers.serverDH_G.length); + ssl->buffers.serverDH_G.length, + NULL, 0, 0, ssl->rng); if (ret != 0) { goto exit_scke; } @@ -19203,11 +19204,12 @@ int SendClientKeyExchange(WOLFSSL* ssl) goto exit_scke; } - ret = wc_DhSetKey(ssl->buffers.serverDH_Key, + ret = wc_DhSetCheckKey(ssl->buffers.serverDH_Key, ssl->buffers.serverDH_P.buffer, ssl->buffers.serverDH_P.length, ssl->buffers.serverDH_G.buffer, - ssl->buffers.serverDH_G.length); + ssl->buffers.serverDH_G.length, + NULL, 0, 0, ssl->rng); if (ret != 0) { goto exit_scke; } @@ -20917,11 +20919,12 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, goto exit_sske; } - ret = wc_DhSetKey(ssl->buffers.serverDH_Key, + ret = wc_DhSetCheckKey(ssl->buffers.serverDH_Key, ssl->buffers.serverDH_P.buffer, ssl->buffers.serverDH_P.length, ssl->buffers.serverDH_G.buffer, - ssl->buffers.serverDH_G.length); + ssl->buffers.serverDH_G.length, + NULL, 0, 1, ssl->rng); if (ret != 0) { goto exit_sske; } @@ -24447,11 +24450,12 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, goto exit_dcke; } - ret = wc_DhSetKey(ssl->buffers.serverDH_Key, + ret = wc_DhSetCheckKey(ssl->buffers.serverDH_Key, ssl->buffers.serverDH_P.buffer, ssl->buffers.serverDH_P.length, ssl->buffers.serverDH_G.buffer, - ssl->buffers.serverDH_G.length); + ssl->buffers.serverDH_G.length, + NULL, 0, 1, ssl->rng); /* set the max agree result size */ ssl->arrays->preMasterSz = ENCRYPT_LEN; @@ -24503,11 +24507,12 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, goto exit_dcke; } - ret = wc_DhSetKey(ssl->buffers.serverDH_Key, + ret = wc_DhSetCheckKey(ssl->buffers.serverDH_Key, ssl->buffers.serverDH_P.buffer, ssl->buffers.serverDH_P.length, ssl->buffers.serverDH_G.buffer, - ssl->buffers.serverDH_G.length); + ssl->buffers.serverDH_G.length, + NULL, 0, 1, ssl->rng); break; } diff --git a/wolfcrypt/src/integer.c b/wolfcrypt/src/integer.c index 819c4789f..13a1ab5c1 100644 --- a/wolfcrypt/src/integer.c +++ b/wolfcrypt/src/integer.c @@ -4585,7 +4585,7 @@ int mp_prime_is_prime_ex (mp_int * a, int t, int *result, WC_RNG *rng) } baseSz = mp_count_bits(a); - baseSz = (baseSz / 8) + (baseSz % 8) ? 1 : 0; + baseSz = (baseSz / 8) + ((baseSz % 8) ? 1 : 0); base = (byte*)XMALLOC(baseSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); if (base == NULL) { diff --git a/wolfcrypt/src/tfm.c b/wolfcrypt/src/tfm.c index aef6cf2b9..3e4a8e8d0 100644 --- a/wolfcrypt/src/tfm.c +++ b/wolfcrypt/src/tfm.c @@ -2962,7 +2962,7 @@ int mp_prime_is_prime_ex(mp_int* a, int t, int* result, WC_RNG* rng) #endif baseSz = fp_count_bits(a); - baseSz = (baseSz / 8) + (baseSz % 8) ? 1 : 0; + baseSz = (baseSz / 8) + ((baseSz % 8) ? 1 : 0); #ifdef WOLFSSL_SMALL_STACK base = (byte*)XMALLOC(baseSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index 65c141022..eef29de76 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -331,6 +331,9 @@ int memory_test(void); #ifdef HAVE_VALGRIND int mp_test(void); #endif +#ifdef WOLFSSL_PUBLIC_MP +int prime_test(void); +#endif #ifdef ASN_BER_TO_DER int berder_test(void); #endif @@ -928,6 +931,13 @@ initDefaultName(); printf( "mp test passed!\n"); #endif +#ifdef WOLFSSL_PUBLIC_MP + if ( (ret = prime_test()) != 0) + return err_sys("prime test failed!\n", ret); + else + printf( "prime test passed!\n"); +#endif + #if defined(ASN_BER_TO_DER) && \ (defined(WOLFSSL_TEST_CERT) || defined(OPENSSL_EXTRA) || \ defined(OPENSSL_EXTRA_X509_SMALL)) @@ -19131,6 +19141,189 @@ done: } #endif + +#ifdef WOLFSSL_PUBLIC_MP + +typedef struct pairs_t { + const unsigned char* coeff; + int coeffSz; + int exp; +} pairs_t; + + +/* +n =p1p2p3, where pi = ki(p1−1)+1 with (k2,k3) = (173,293) +p1 = 2^192 * 0x000000000000e24fd4f6d6363200bf2323ec46285cac1d3a + + 2^0 * 0x0b2488b0c29d96c5e67f8bec15b54b189ae5636efe89b45b +*/ + +const unsigned char c192a[] = +{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe2, 0x4f, + 0xd4, 0xf6, 0xd6, 0x36, 0x32, 0x00, 0xbf, 0x23, + 0x23, 0xec, 0x46, 0x28, 0x5c, 0xac, 0x1d, 0x3a +}; +const unsigned char c0a[] = +{ + 0x0b, 0x24, 0x88, 0xb0, 0xc2, 0x9d, 0x96, 0xc5, + 0xe6, 0x7f, 0x8b, 0xec, 0x15, 0xb5, 0x4b, 0x18, + 0x9a, 0xe5, 0x63, 0x6e, 0xfe, 0x89, 0xb4, 0x5b +}; + +const pairs_t ecPairsA[] = +{ + {c192a, sizeof(c192a), 192}, + {c0a, sizeof(c0a), 0} +}; + +const int kA[] = {173, 293}; + +const unsigned char controlPrime[] = { + 0xe1, 0x76, 0x45, 0x80, 0x59, 0xb6, 0xd3, 0x49, + 0xdf, 0x0a, 0xef, 0x12, 0xd6, 0x0f, 0xf0, 0xb7, + 0xcb, 0x2a, 0x37, 0xbf, 0xa7, 0xf8, 0xb5, 0x4d, + 0xf5, 0x31, 0x35, 0xad, 0xe4, 0xa3, 0x94, 0xa1, + 0xdb, 0xf1, 0x96, 0xad, 0xb5, 0x05, 0x64, 0x85, + 0x83, 0xfc, 0x1b, 0x5b, 0x29, 0xaa, 0xbe, 0xf8, + 0x26, 0x3f, 0x76, 0x7e, 0xad, 0x1c, 0xf0, 0xcb, + 0xd7, 0x26, 0xb4, 0x1b, 0x05, 0x8e, 0x56, 0x86, + 0x7e, 0x08, 0x62, 0x21, 0xc1, 0x86, 0xd6, 0x47, + 0x79, 0x3e, 0xb7, 0x5d, 0xa4, 0xc6, 0x3a, 0xd7, + 0xb1, 0x74, 0x20, 0xf6, 0x50, 0x97, 0x41, 0x04, + 0x53, 0xed, 0x3f, 0x26, 0xd6, 0x6f, 0x91, 0xfa, + 0x68, 0x26, 0xec, 0x2a, 0xdc, 0x9a, 0xf1, 0xe7, + 0xdc, 0xfb, 0x73, 0xf0, 0x79, 0x43, 0x1b, 0x21, + 0xa3, 0x59, 0x04, 0x63, 0x52, 0x07, 0xc9, 0xd7, + 0xe6, 0xd1, 0x1b, 0x5d, 0x5e, 0x96, 0xfa, 0x53 +}; + + +static int GenerateNextP(mp_int* p1, mp_int* p2, int k) +{ + int ret; + + ret = mp_sub_d(p1, 1, p2); + if (ret == 0) + ret = mp_mul_d(p2, k, p2); + if (ret == 0) + ret = mp_add_d(p2, 1, p2); + + return ret; +} + + +static int GenerateP(mp_int* p1, mp_int* p2, mp_int* p3, + const pairs_t* ecPairs, int ecPairsSz, + const int* k) +{ + mp_int x,y; + int ret, i; + + ret = mp_init(&x); + if (ret == 0) { + ret = mp_init(&y); + if (ret != 0) { + mp_clear(&x); + return MP_MEM; + } + } + for (i = 0; ret == 0 && i < ecPairsSz; i++) { + ret = mp_read_unsigned_bin(&x, ecPairs[i].coeff, ecPairs[i].coeffSz); + /* p1 = 2^exp */ + if (ret == 0) + ret = mp_2expt(&y, ecPairs[i].exp); + /* p1 = p1 * m */ + if (ret == 0) + ret = mp_mul(&x, &y, &x); + /* p1 += */ + if (ret == 0) + ret = mp_add(p1, &x, p1); + mp_zero(&x); + mp_zero(&y); + } + mp_clear(&x); + mp_clear(&y); + + if (ret == 0) + ret = GenerateNextP(p1, p2, k[0]); + if (ret == 0) + ret = GenerateNextP(p1, p3, k[1]); + + return ret; +} + +int prime_test(void) +{ + mp_int n, p1, p2, p3; + int ret, isPrime = 0; + WC_RNG rng; + + ret = wc_InitRng(&rng); + if (ret == 0) + ret = mp_init_multi(&n, &p1, &p2, &p3, NULL, NULL); + if (ret == 0) + ret = GenerateP(&p1, &p2, &p3, + ecPairsA, sizeof(ecPairsA) / sizeof(ecPairsA[0]), kA); + if (ret == 0) + ret = mp_mul(&p1, &p2, &n); + if (ret == 0) + ret = mp_mul(&n, &p3, &n); + if (ret != 0) + return -9650; + + /* Check the old prime test using the number that false positives. + * This test result should indicate as not prime. */ + ret = mp_prime_is_prime(&n, 40, &isPrime); + if (ret != 0) + return -9651; + if (isPrime) + return -9652; + + /* This test result should fail. It should indicate the value as prime. */ + ret = mp_prime_is_prime(&n, 8, &isPrime); + if (ret != 0) + return -9653; + if (!isPrime) + return -9654; + + /* This test result should indicate the value as not prime. */ + ret = mp_prime_is_prime_ex(&n, 8, &isPrime, &rng); + if (ret != 0) + return -9655; + if (isPrime) + return -9656; + + ret = mp_read_unsigned_bin(&n, controlPrime, sizeof(controlPrime)); + if (ret != 0) + return -9657; + + /* This test result should indicate the value as prime. */ + ret = mp_prime_is_prime_ex(&n, 8, &isPrime, &rng); + if (ret != 0) + return -9658; + if (!isPrime) + return -9659; + + /* This test result should indicate the value as prime. */ + isPrime = -1; + ret = mp_prime_is_prime(&n, 8, &isPrime); + if (ret != 0) + return -9660; + if (!isPrime) + return -9661; + + mp_clear(&p3); + mp_clear(&p2); + mp_clear(&p1); + mp_clear(&n); + wc_FreeRng(&rng); + + return 0; +} + +#endif /* WOLFSSL_PUBLIC_MP */ + + #if defined(ASN_BER_TO_DER) && \ (defined(WOLFSSL_TEST_CERT) || defined(OPENSSL_EXTRA) || \ defined(OPENSSL_EXTRA_X509_SMALL))