diff --git a/include/boost/integer/extended_euclidean.hpp b/include/boost/integer/extended_euclidean.hpp index edca600..f89a540 100644 --- a/include/boost/integer/extended_euclidean.hpp +++ b/include/boost/integer/extended_euclidean.hpp @@ -6,12 +6,9 @@ */ #ifndef BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP #define BOOST_INTEGER_EXTENDED_EUCLIDEAN_HPP -#include #include #include #include -#include -#include namespace boost { namespace integer { @@ -71,14 +68,8 @@ euclidean_result_t extended_euclidean(Z m, Z n) if (swapped) { - std::swap(u1, u2); - BOOST_ASSERT(u2*m+u1*n==u0); + return {u0, u2, u2}; } - else - { - BOOST_ASSERT(u1*m+u2*n==u0); - } - return {u0, u1, u2}; } diff --git a/include/boost/integer/mod_inverse.hpp b/include/boost/integer/mod_inverse.hpp index 9a93585..7053efc 100644 --- a/include/boost/integer/mod_inverse.hpp +++ b/include/boost/integer/mod_inverse.hpp @@ -35,20 +35,19 @@ boost::optional mod_inverse(Z a, Z modulus) return {}; } euclidean_result_t u = extended_euclidean(a, modulus); - Z gcd = u.gcd; - if (gcd > 1) + if (u.gcd > 1) { return {}; } - Z x = u.x; - x = x % modulus; // x might not be in the range 0 < x < m, let's fix that: - while (x <= 0) + while (u.x <= 0) { - x += modulus; + u.x += modulus; } - BOOST_ASSERT(x*a % modulus == 1); - return x; + // While indeed this is an inexpensive and comforting check, + // the multiplication overflows and hence makes the check itself buggy. + //BOOST_ASSERT(u.x*a % modulus == 1); + return u.x; } }} diff --git a/test/extended_euclidean_test.cpp b/test/extended_euclidean_test.cpp index 2179016..16fc0ff 100644 --- a/test/extended_euclidean_test.cpp +++ b/test/extended_euclidean_test.cpp @@ -11,6 +11,7 @@ #include using boost::multiprecision::int128_t; +using boost::multiprecision::int256_t; using boost::integer::extended_euclidean; using boost::integer::gcd; @@ -18,26 +19,29 @@ template void test_extended_euclidean() { std::cout << "Testing the extended Euclidean algorithm on type " << boost::typeindex::type_id().pretty_name() << "\n"; + // Stress test: + //Z max_arg = std::numeric_limits::max(); Z max_arg = 500; - for (Z m = 1; m < max_arg; ++m) + for (Z m = max_arg; m > 0; --m) { - for (Z n = 1; n < max_arg; ++n) + for (Z n = m; n > 0; --n) { boost::integer::euclidean_result_t u = extended_euclidean(m, n); - Z gcdmn = gcd(m, n); - Z x = u.x; - Z y = u.y; + int256_t gcdmn = gcd(m, n); + int256_t x = u.x; + int256_t y = u.y; BOOST_CHECK_EQUAL(u.gcd, gcdmn); BOOST_CHECK_EQUAL(m*x + n*y, gcdmn); } } } + + BOOST_AUTO_TEST_CASE(extended_euclidean_test) { - test_extended_euclidean(); - test_extended_euclidean(); - test_extended_euclidean(); - test_extended_euclidean(); + test_extended_euclidean(); + test_extended_euclidean(); + test_extended_euclidean(); test_extended_euclidean(); } diff --git a/test/mod_inverse_test.cpp b/test/mod_inverse_test.cpp index ac7862f..4ab02f8 100644 --- a/test/mod_inverse_test.cpp +++ b/test/mod_inverse_test.cpp @@ -19,10 +19,15 @@ template void test_mod_inverse() { std::cout << "Testing the modular multiplicative inverse on type " << boost::typeindex::type_id().pretty_name() << "\n"; + //Z max_arg = std::numeric_limits::max(); Z max_arg = 500; for (Z modulus = 2; modulus < max_arg; ++modulus) { - for (Z a = 1; a < max_arg; ++a) + if (modulus % 1000 == 0) + { + std::cout << "Testing all inverses modulo " << modulus << std::endl; + } + for (Z a = 1; a < modulus; ++a) { Z gcdam = gcd(a, modulus); boost::optional inv_a = mod_inverse(a, modulus); @@ -34,7 +39,11 @@ void test_mod_inverse() else { BOOST_CHECK(inv_a.value() > 0); - Z outta_be_one = (inv_a.value()*a) % modulus; + // Cast to a bigger type so the multiplication won't overflow. + int256_t a_inv = inv_a.value(); + int256_t big_a = a; + int256_t m = modulus; + int256_t outta_be_one = (a_inv*big_a) % m; BOOST_CHECK_EQUAL(outta_be_one, 1); } } @@ -43,10 +52,8 @@ void test_mod_inverse() BOOST_AUTO_TEST_CASE(mod_inverse_test) { - test_mod_inverse(); - test_mod_inverse(); - test_mod_inverse(); - test_mod_inverse(); + test_mod_inverse(); + test_mod_inverse(); + test_mod_inverse(); test_mod_inverse(); - test_mod_inverse(); }