a*p % m may overflow, do not perform naive multiplication in unit tests or undefined behavior may result. [CI SKIP]

This commit is contained in:
Nick Thompson
2018-10-26 11:19:43 -06:00
parent 3f1603938c
commit 2d463f3ee7
4 changed files with 35 additions and 34 deletions

View File

@ -6,12 +6,9 @@
*/ */
#ifndef BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP #ifndef BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP
#define BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP #define BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP
#include <tuple>
#include <limits> #include <limits>
#include <stdexcept> #include <stdexcept>
#include <boost/throw_exception.hpp> #include <boost/throw_exception.hpp>
#include <boost/assert.hpp>
#include <iostream>
namespace boost { namespace integer { namespace boost { namespace integer {
@ -71,14 +68,8 @@ euclidean_result_t<Z> extended_euclidean(Z m, Z n)
if (swapped) if (swapped)
{ {
std::swap(u1, u2); return {u0, u2, u2};
BOOST_ASSERT(u2*m+u1*n==u0);
} }
else
{
BOOST_ASSERT(u1*m+u2*n==u0);
}
return {u0, u1, u2}; return {u0, u1, u2};
} }

View File

@ -35,20 +35,19 @@ boost::optional<Z> mod_inverse(Z a, Z modulus)
return {}; return {};
} }
euclidean_result_t<Z> u = extended_euclidean(a, modulus); euclidean_result_t<Z> u = extended_euclidean(a, modulus);
Z gcd = u.gcd; if (u.gcd > 1)
if (gcd > 1)
{ {
return {}; return {};
} }
Z x = u.x;
x = x % modulus;
// x might not be in the range 0 < x < m, let's fix that: // x might not be in the range 0 < x < m, let's fix that:
while (x <= 0) while (u.x <= 0)
{ {
x += modulus; u.x += modulus;
} }
BOOST_ASSERT(x*a % modulus == 1); // While indeed this is an inexpensive and comforting check,
return x; // the multiplication overflows and hence makes the check itself buggy.
//BOOST_ASSERT(u.x*a % modulus == 1);
return u.x;
} }
}} }}

View File

@ -11,6 +11,7 @@
#include <boost/integer/extended_euclidean.hpp> #include <boost/integer/extended_euclidean.hpp>
using boost::multiprecision::int128_t; using boost::multiprecision::int128_t;
using boost::multiprecision::int256_t;
using boost::integer::extended_euclidean; using boost::integer::extended_euclidean;
using boost::integer::gcd; using boost::integer::gcd;
@ -18,26 +19,29 @@ template<class Z>
void test_extended_euclidean() void test_extended_euclidean()
{ {
std::cout << "Testing the extended Euclidean algorithm on type " << boost::typeindex::type_id<Z>().pretty_name() << "\n"; std::cout << "Testing the extended Euclidean algorithm on type " << boost::typeindex::type_id<Z>().pretty_name() << "\n";
// Stress test:
//Z max_arg = std::numeric_limits<Z>::max();
Z max_arg = 500; Z max_arg = 500;
for (Z m = 1; m < max_arg; ++m) for (Z m = max_arg; m > 0; --m)
{ {
for (Z n = 1; n < max_arg; ++n) for (Z n = m; n > 0; --n)
{ {
boost::integer::euclidean_result_t u = extended_euclidean(m, n); boost::integer::euclidean_result_t u = extended_euclidean(m, n);
Z gcdmn = gcd(m, n); int256_t gcdmn = gcd(m, n);
Z x = u.x; int256_t x = u.x;
Z y = u.y; int256_t y = u.y;
BOOST_CHECK_EQUAL(u.gcd, gcdmn); BOOST_CHECK_EQUAL(u.gcd, gcdmn);
BOOST_CHECK_EQUAL(m*x + n*y, gcdmn); BOOST_CHECK_EQUAL(m*x + n*y, gcdmn);
} }
} }
} }
BOOST_AUTO_TEST_CASE(extended_euclidean_test) BOOST_AUTO_TEST_CASE(extended_euclidean_test)
{ {
test_extended_euclidean<short int>(); test_extended_euclidean<int16_t>();
test_extended_euclidean<int>(); test_extended_euclidean<int32_t>();
test_extended_euclidean<long>(); test_extended_euclidean<int64_t>();
test_extended_euclidean<long long>();
test_extended_euclidean<int128_t>(); test_extended_euclidean<int128_t>();
} }

View File

@ -19,10 +19,15 @@ template<class Z>
void test_mod_inverse() void test_mod_inverse()
{ {
std::cout << "Testing the modular multiplicative inverse on type " << boost::typeindex::type_id<Z>().pretty_name() << "\n"; std::cout << "Testing the modular multiplicative inverse on type " << boost::typeindex::type_id<Z>().pretty_name() << "\n";
//Z max_arg = std::numeric_limits<Z>::max();
Z max_arg = 500; Z max_arg = 500;
for (Z modulus = 2; modulus < max_arg; ++modulus) for (Z modulus = 2; modulus < max_arg; ++modulus)
{ {
for (Z a = 1; a < max_arg; ++a) if (modulus % 1000 == 0)
{
std::cout << "Testing all inverses modulo " << modulus << std::endl;
}
for (Z a = 1; a < modulus; ++a)
{ {
Z gcdam = gcd(a, modulus); Z gcdam = gcd(a, modulus);
boost::optional<Z> inv_a = mod_inverse(a, modulus); boost::optional<Z> inv_a = mod_inverse(a, modulus);
@ -34,7 +39,11 @@ void test_mod_inverse()
else else
{ {
BOOST_CHECK(inv_a.value() > 0); BOOST_CHECK(inv_a.value() > 0);
Z outta_be_one = (inv_a.value()*a) % modulus; // Cast to a bigger type so the multiplication won't overflow.
int256_t a_inv = inv_a.value();
int256_t big_a = a;
int256_t m = modulus;
int256_t outta_be_one = (a_inv*big_a) % m;
BOOST_CHECK_EQUAL(outta_be_one, 1); BOOST_CHECK_EQUAL(outta_be_one, 1);
} }
} }
@ -43,10 +52,8 @@ void test_mod_inverse()
BOOST_AUTO_TEST_CASE(mod_inverse_test) BOOST_AUTO_TEST_CASE(mod_inverse_test)
{ {
test_mod_inverse<short int>(); test_mod_inverse<int16_t>();
test_mod_inverse<int>(); test_mod_inverse<int32_t>();
test_mod_inverse<long>(); test_mod_inverse<int64_t>();
test_mod_inverse<long long>();
test_mod_inverse<int128_t>(); test_mod_inverse<int128_t>();
test_mod_inverse<int256_t>();
} }