diff --git a/wolfcrypt/src/integer.c b/wolfcrypt/src/integer.c index 3c86e665a..e54d192da 100644 --- a/wolfcrypt/src/integer.c +++ b/wolfcrypt/src/integer.c @@ -1145,23 +1145,28 @@ int mp_invmod_slow (mp_int * a, mp_int * b, mp_int * c) /* init temps */ if ((res = mp_init_multi(&x, &y, &u, &v, &A, &B)) != MP_OKAY) { - return res; + return res; } /* init rest of tmps temps */ if ((res = mp_init_multi(&C, &D, 0, 0, 0, 0)) != MP_OKAY) { - mp_clear(&x); - mp_clear(&y); - mp_clear(&u); - mp_clear(&v); - mp_clear(&A); - mp_clear(&B); - return res; + mp_clear(&x); + mp_clear(&y); + mp_clear(&u); + mp_clear(&v); + mp_clear(&A); + mp_clear(&B); + return res; } /* x = a, y = b */ if ((res = mp_mod(a, b, &x)) != MP_OKAY) { - goto LBL_ERR; + goto LBL_ERR; + } + if (mp_isone(&x)) { + mp_set(c, 1); + res = MP_OKAY; + goto LBL_ERR; } if ((res = mp_copy (b, &y)) != MP_OKAY) { goto LBL_ERR; @@ -1198,10 +1203,10 @@ top: if (mp_isodd (&A) == MP_YES || mp_isodd (&B) == MP_YES) { /* A = (A+y)/2, B = (B-x)/2 */ if ((res = mp_add (&A, &y, &A)) != MP_OKAY) { - goto LBL_ERR; + goto LBL_ERR; } if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) { - goto LBL_ERR; + goto LBL_ERR; } } /* A = A/2, B = B/2 */ @@ -1223,10 +1228,10 @@ top: if (mp_isodd (&C) == MP_YES || mp_isodd (&D) == MP_YES) { /* C = (C+y)/2, D = (D-x)/2 */ if ((res = mp_add (&C, &y, &C)) != MP_OKAY) { - goto LBL_ERR; + goto LBL_ERR; } if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) { - goto LBL_ERR; + goto LBL_ERR; } } /* C = C/2, D = D/2 */ diff --git a/wolfcrypt/src/tfm.c b/wolfcrypt/src/tfm.c index 1f875d0cd..dacb722d1 100644 --- a/wolfcrypt/src/tfm.c +++ b/wolfcrypt/src/tfm.c @@ -897,6 +897,9 @@ static int fp_invmod_slow (fp_int * a, fp_int * b, fp_int * c) if (b->sign == FP_NEG || fp_iszero(b) == FP_YES) { return FP_VAL; } + if (fp_iszero(a) == FP_YES) { + return FP_VAL; + } #ifdef WOLFSSL_SMALL_STACK x = (fp_int*)XMALLOC(sizeof(fp_int) * 8, NULL, DYNAMIC_TYPE_BIGINT); @@ -922,7 +925,7 @@ static int fp_invmod_slow (fp_int * a, fp_int * b, fp_int * c) fp_copy(b, y); /* 2. [modified] if x,y are both even then return an error! */ - if (fp_iseven (x) == FP_YES && fp_iseven (y) == FP_YES) { + if (fp_iseven(x) == FP_YES && fp_iseven(y) == FP_YES) { #ifdef WOLFSSL_SMALL_STACK XFREE(x, NULL, DYNAMIC_TYPE_BIGINT); #endif @@ -1022,9 +1025,14 @@ int fp_invmod(fp_int *a, fp_int *b, fp_int *c) fp_int *x, *y, *u, *v, *B, *D; #endif int neg; + int err; + + if (b->sign == FP_NEG || fp_iszero(b) == FP_YES) { + return FP_VAL; + } /* 2. [modified] b must be odd */ - if (fp_iseven (b) == FP_YES) { + if (fp_iseven(b) == FP_YES) { return fp_invmod_slow(a,b,c); } @@ -1041,6 +1049,24 @@ int fp_invmod(fp_int *a, fp_int *b, fp_int *c) fp_init(u); fp_init(v); fp_init(B); fp_init(D); + if (fp_cmp(a, b) != MP_LT) { + err = mp_mod(a, b, y); + if (err != FP_OKAY) { + #ifdef WOLFSSL_SMALL_STACK + XFREE(x, NULL, DYNAMIC_TYPE_BIGINT); + #endif + return err; + } + a = y; + } + + if (fp_iszero(a) == FP_YES) { + #ifdef WOLFSSL_SMALL_STACK + XFREE(x, NULL, DYNAMIC_TYPE_BIGINT); + #endif + return FP_VAL; + } + /* x == modulus, y == value to invert */ fp_copy(b, x);