Fix for mp_mulmod with NXP LTC.

This commit is contained in:
David Garske
2021-06-01 16:33:58 -07:00
parent 8bf2cbf55e
commit eb63ab19e2

View File

@ -44,6 +44,9 @@
/* For debugging only - Enable this to do software tests of each operation */
/* #define ENABLE_NXPLTC_TESTS */
#ifdef ENABLE_NXPLTC_TESTS
static int doLtcTest = 0;
#endif
int ksdk_port_init(void)
{
@ -111,7 +114,6 @@ static int ltc_get_lsb_bin_from_mp_int(uint8_t *dst, mp_int *A, uint16_t *psz)
/* LTC TFM */
#if defined(FREESCALE_LTC_TFM)
/* these function are used by wolfSSL upper layers (like RSA) */
/* c = a * b */
@ -124,7 +126,8 @@ int mp_mul(mp_int *A, mp_int *B, mp_int *C)
#ifdef ENABLE_NXPLTC_TESTS
mp_int t;
mp_init(&t);
wolfcrypt_mp_mul(A, B, &t);
if (doLtcTest)
wolfcrypt_mp_mul(A, B, &t);
#endif
szA = mp_unsigned_bin_size(A);
@ -154,10 +157,13 @@ int mp_mul(mp_int *A, mp_int *B, mp_int *C)
XMEMSET(ptrN, 0xFF, sizeN);
XMEMSET(ptrC, 0, LTC_MAX_INT_BYTES);
status = LTC_PKHA_ModMul(LTC_BASE, ptrA, sizeA, ptrB, sizeB,
ptrN, sizeN, ptrC, &sizeC, kLTC_PKHA_IntegerArith,
kLTC_PKHA_NormalValue, kLTC_PKHA_NormalValue,
kLTC_PKHA_TimingEqualized);
status = LTC_PKHA_ModMul(LTC_BASE,
ptrA, sizeA, /* first integer */
ptrB, sizeB, /* second integer */
ptrN, sizeN, /* modulus */
ptrC, &sizeC, /* out */
kLTC_PKHA_IntegerArith, kLTC_PKHA_NormalValue,
kLTC_PKHA_NormalValue, kLTC_PKHA_TimingEqualized);
if (status == kStatus_Success) {
ltc_reverse_array(ptrC, sizeC);
res = mp_read_unsigned_bin(C, ptrC, sizeC);
@ -196,7 +202,7 @@ int mp_mul(mp_int *A, mp_int *B, mp_int *C)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (mp_cmp(&t, C) != MP_EQ) {
if (doLtcTest && mp_cmp(&t, C) != MP_EQ) {
printf("mp_mul test fail!\n");
mp_dump("C", C, 0);
@ -217,7 +223,8 @@ int mp_mod(mp_int *a, mp_int *b, mp_int *c)
#ifdef ENABLE_NXPLTC_TESTS
mp_int t;
mp_init(&t);
wolfcrypt_mp_mod(a, b, &t);
if (doLtcTest)
wolfcrypt_mp_mod(a, b, &t);
#endif
szA = mp_unsigned_bin_size(a);
@ -282,7 +289,7 @@ int mp_mod(mp_int *a, mp_int *b, mp_int *c)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (mp_cmp(&t, c) != MP_EQ) {
if (doLtcTest && mp_cmp(&t, c) != MP_EQ) {
printf("mp_mod test fail!\n");
mp_dump("C", c, 0);
@ -303,7 +310,8 @@ int mp_invmod(mp_int *a, mp_int *b, mp_int *c)
#ifdef ENABLE_NXPLTC_TESTS
mp_int t;
mp_init(&t);
wolfcrypt_mp_invmod(a, b, &t);
if (doLtcTest)
wolfcrypt_mp_invmod(a, b, &t);
#endif
szA = mp_unsigned_bin_size(a);
@ -321,8 +329,8 @@ int mp_invmod(mp_int *a, mp_int *b, mp_int *c)
res = ltc_get_lsb_bin_from_mp_int(ptrB, b, &sizeB);
/* if a >= b then reduce */
if (res == MP_OKAY && LTC_PKHA_CompareBigNum(ptrA, sizeA, ptrB,
sizeB) >= 0) {
if (res == MP_OKAY &&
LTC_PKHA_CompareBigNum(ptrA, sizeA, ptrB, sizeB) >= 0) {
if (LTC_PKHA_ModRed(LTC_BASE, ptrA, sizeA, ptrB, sizeB,
ptrA, &sizeA, kLTC_PKHA_IntegerArith) != kStatus_Success) {
res = MP_VAL;
@ -367,7 +375,7 @@ int mp_invmod(mp_int *a, mp_int *b, mp_int *c)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (mp_cmp(&t, c) != MP_EQ) {
if (doLtcTest && mp_cmp(&t, c) != MP_EQ) {
printf("mp_invmod test fail!\n");
mp_dump("C", c, 0);
@ -389,7 +397,8 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
#ifdef ENABLE_NXPLTC_TESTS
mp_int t;
mp_init(&t);
wolfcrypt_mp_mulmod(a, b, c, &t);
if (doLtcTest)
wolfcrypt_mp_mulmod(a, b, c, &t);
#endif
szA = mp_unsigned_bin_size(a);
@ -399,77 +408,63 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
if ((szA <= LTC_MAX_INT_BYTES) && (szB <= LTC_MAX_INT_BYTES) &&
(szC <= LTC_MAX_INT_BYTES))
{
int neg = 0;
uint8_t *ptrA, *ptrB, *ptrC, *ptrD;
#ifndef WOLFSSL_SP_MATH
/* if A or B is negative, subtract abs(A) or abs(B) from modulus to get
* positive integer representation of the same number */
mp_int aabs, babs;
res = mp_init_multi(&aabs, &babs, NULL, NULL, NULL, NULL);
if (res != MP_OKAY) {
return res;
}
if (a->sign)
res = mp_add(a, c, &aabs);
else
res = mp_copy(a, &aabs);
if (res == MP_OKAY) {
if (b->sign)
res = mp_add(b, c, &babs);
else
res = mp_copy(b, &babs);
}
if (res != MP_OKAY) {
mp_clear(&aabs);
mp_clear(&babs);
return res;
}
#endif
ptrA = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
ptrB = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
ptrC = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
ptrD = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
/* unsigned multiply */
#ifndef WOLFSSL_SP_MATH
neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
#endif
if (ptrA && ptrB && ptrC && ptrD) {
uint16_t sizeA, sizeB, sizeC, sizeD;
uint16_t sizeA, sizeB, sizeC, sizeD = 0;
res = ltc_get_lsb_bin_from_mp_int(ptrA, &aabs, &sizeA);
/* Multiply A * B = D */
res = ltc_get_lsb_bin_from_mp_int(ptrA, a, &sizeA);
if (res == MP_OKAY)
res = ltc_get_lsb_bin_from_mp_int(ptrB, &babs, &sizeB);
if (res == MP_OKAY)
res = ltc_get_lsb_bin_from_mp_int(ptrC, c, &sizeC);
/* (A*B)mod C = ((A mod C) * (B mod C)) mod C */
if (res == MP_OKAY && LTC_PKHA_CompareBigNum(ptrA, sizeA, ptrC,
sizeC) >= 0) {
status = LTC_PKHA_ModRed(LTC_BASE, ptrA, sizeA, ptrC, sizeC,
ptrA, &sizeA, kLTC_PKHA_IntegerArith);
if (status != kStatus_Success) {
res = MP_VAL;
}
}
if (res == MP_OKAY && (LTC_PKHA_CompareBigNum(ptrB, sizeB, ptrC,
sizeC) >= 0)) {
status = LTC_PKHA_ModRed(LTC_BASE, ptrB, sizeB, ptrC, sizeC,
ptrB, &sizeB, kLTC_PKHA_IntegerArith);
if (status != kStatus_Success) {
res = MP_VAL;
}
}
res = ltc_get_lsb_bin_from_mp_int(ptrB, b, &sizeB);
if (res == MP_OKAY) {
status = LTC_PKHA_ModMul(LTC_BASE, ptrA, sizeA,
ptrB, sizeB, ptrC, sizeC, ptrD, &sizeD,
/* modulus C is all F's for integer multiply */
sizeC = sizeA + sizeB;
XMEMSET(ptrC, 0xFF, sizeC);
XMEMSET(ptrD, 0, LTC_MAX_INT_BYTES);
status = LTC_PKHA_ModMul(LTC_BASE,
ptrA, sizeA, /* first integer */
ptrB, sizeB, /* second integer */
ptrC, sizeC, /* modulus */
ptrD, &sizeD, /* out */
kLTC_PKHA_IntegerArith, kLTC_PKHA_NormalValue,
kLTC_PKHA_NormalValue, kLTC_PKHA_TimingEqualized);
if (status == kStatus_Success) {
ltc_reverse_array(ptrD, sizeD);
res = mp_read_unsigned_bin(d, ptrD, sizeD);
}
else {
if (status != kStatus_Success)
res = MP_VAL;
}
}
/* load modulus */
if (res == MP_OKAY)
res = ltc_get_lsb_bin_from_mp_int(ptrC, c, &sizeC);
/* perform D mod C = D */
if (res == MP_OKAY) {
status = LTC_PKHA_ModRed(LTC_BASE,
ptrD, sizeD,
ptrC, sizeC,
ptrD, &sizeD,
kLTC_PKHA_IntegerArith);
if (status != kStatus_Success)
res = MP_VAL;
}
if (res == MP_OKAY) {
ltc_reverse_array(ptrD, sizeD);
res = mp_read_unsigned_bin(d, ptrD, sizeD);
#ifndef WOLFSSL_SP_MATH
/* fix sign */
d->sign = neg;
#endif
}
}
else {
@ -488,11 +483,6 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
if (ptrD) {
XFREE(ptrD, NULL, DYNAMIC_TYPE_BIGINT);
}
#ifndef WOLFSSL_SP_MATH
mp_clear(&aabs);
mp_clear(&babs);
#endif
}
else {
#if defined(FREESCALE_LTC_TFM_RSA_4096_ENABLE)
@ -504,7 +494,7 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (mp_cmp(&t, d) != MP_EQ) {
if (doLtcTest && mp_cmp(&t, d) != MP_EQ) {
printf("mp_mulmod test fail!\n");
mp_dump("D", d, 0);
@ -520,111 +510,77 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y)
{
int res = MP_OKAY;
int szA, szB, szC;
#if defined(FREESCALE_LTC_TFM_RSA_4096_ENABLE)
mp_int tmp;
#endif
int szG, szX, szP;
#ifdef ENABLE_NXPLTC_TESTS
mp_int t;
mp_init(&t);
res = wolfcrypt_mp_exptmod(G, X, P, Y);
if (doLtcTest)
wolfcrypt_mp_exptmod(G, X, P, &t);
#endif
/* if G cannot fit into LTC_PKHA, reduce it */
szA = mp_unsigned_bin_size(G);
#if defined(FREESCALE_LTC_TFM_RSA_4096_ENABLE)
if (szA > LTC_MAX_INT_BYTES) {
res = mp_init(&tmp);
if (res != MP_OKAY)
return res;
if ((res = mp_mod(G, P, &tmp)) != MP_OKAY) {
mp_clear(&tmp);
return res;
}
G = &tmp;
szA = mp_unsigned_bin_size(G);
}
#endif
szB = mp_unsigned_bin_size(X);
szC = mp_unsigned_bin_size(P);
szG = mp_unsigned_bin_size(G);
szX = mp_unsigned_bin_size(X);
szP = mp_unsigned_bin_size(P);
if ((szA <= LTC_MAX_INT_BYTES) &&
(szB <= LTC_MAX_INT_BYTES) &&
(szC <= LTC_MAX_INT_BYTES))
if ((szG <= LTC_MAX_INT_BYTES) &&
(szX <= LTC_MAX_INT_BYTES) &&
(szP <= LTC_MAX_INT_BYTES))
{
uint16_t sizeG, sizeX, sizeP;
uint8_t *ptrG, *ptrX, *ptrP;
/* if G is negative, add modulus to convert to positive number for LTC */
#ifndef WOLFSSL_SP_MATH
mp_int gabs;
res = mp_init(&gabs);
if (G->sign)
res = mp_add(G, P, &gabs);
else
res = mp_copy(G, &gabs);
if (res != MP_OKAY) {
mp_clear(&gabs);
return res;
}
#endif
uint16_t sizeG, sizeX, sizeP, sizeY;
uint8_t *ptrG, *ptrX, *ptrP, *ptrY;
ptrG = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
ptrX = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
ptrP = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
if (ptrG && ptrX && ptrP) {
res = ltc_get_lsb_bin_from_mp_int(ptrG, &gabs, &sizeG);
ptrY = (uint8_t*)XMALLOC(LTC_MAX_INT_BYTES, NULL, DYNAMIC_TYPE_BIGINT);
if (ptrG && ptrX && ptrP && ptrY) {
res = ltc_get_lsb_bin_from_mp_int(ptrG, G, &sizeG);
if (res == MP_OKAY)
res = ltc_get_lsb_bin_from_mp_int(ptrX, X, &sizeX);
if (res == MP_OKAY)
res = ltc_get_lsb_bin_from_mp_int(ptrP, P, &sizeP);
/* if number if greater that modulo, we must first reduce due to LTC
requirement on modular exponentiation */
/* it needs number less than modulus. */
/* we can take advantage of modular arithmetic rule that:
A^B mod C = ( (A mod C)^B ) mod C
and so we do first (A mod N) : LTC does not give size requirement
on A versus N, and then the modular exponentiation.
*/
/* if G >= P then reduce */
if (res == MP_OKAY && LTC_PKHA_CompareBigNum(ptrG, sizeG, ptrP,
sizeP) >= 0) {
if (LTC_PKHA_ModRed(LTC_BASE, ptrG, sizeG, ptrP, sizeP,
ptrG, &sizeG, kLTC_PKHA_IntegerArith) != kStatus_Success) {
res = MP_VAL;
}
if (res == MP_OKAY &&
LTC_PKHA_CompareBigNum(ptrG, sizeG, ptrP, sizeP) >= 0) {
res = LTC_PKHA_ModRed(LTC_BASE,
ptrG, sizeG,
ptrP, sizeP,
ptrG, &sizeG, kLTC_PKHA_IntegerArith);
res = (res == kStatus_Success) ? MP_OKAY: MP_VAL;
}
if (res == MP_OKAY) {
if (LTC_PKHA_ModExp(LTC_BASE, ptrG, sizeG, ptrP, sizeP, ptrX, sizeX,
ptrP, &sizeP, kLTC_PKHA_IntegerArith, kLTC_PKHA_NormalValue,
kLTC_PKHA_TimingEqualized) != kStatus_Success) {
res = MP_VAL;
}
else {
ltc_reverse_array(ptrP, sizeP);
res = mp_read_unsigned_bin(Y, ptrP, sizeP);
}
res = LTC_PKHA_ModExp(LTC_BASE,
ptrG, sizeG, /* integer input */
ptrP, sizeP, /* modulus */
ptrX, sizeX, /* expenoent */
ptrY, &sizeY, /* out */
kLTC_PKHA_IntegerArith, kLTC_PKHA_NormalValue,
kLTC_PKHA_TimingEqualized);
res = (res == kStatus_Success) ? MP_OKAY: MP_VAL;
}
if (res == MP_OKAY) {
ltc_reverse_array(ptrY, sizeY);
res = mp_read_unsigned_bin(Y, ptrY, sizeY);
}
}
else {
res = MP_MEM;
}
if (ptrG) {
XFREE(ptrG, NULL, DYNAMIC_TYPE_BIGINT);
}
if (ptrX) {
XFREE(ptrX, NULL, DYNAMIC_TYPE_BIGINT);
if (ptrY) {
XFREE(ptrY, NULL, DYNAMIC_TYPE_BIGINT);
}
if (ptrP) {
XFREE(ptrP, NULL, DYNAMIC_TYPE_BIGINT);
}
#ifndef WOLFSSL_SP_MATH
mp_clear(&gabs);
#endif
if (ptrX) {
XFREE(ptrX, NULL, DYNAMIC_TYPE_BIGINT);
}
if (ptrG) {
XFREE(ptrG, NULL, DYNAMIC_TYPE_BIGINT);
}
}
else {
#if defined(FREESCALE_LTC_TFM_RSA_4096_ENABLE)
@ -636,7 +592,7 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (mp_cmp(&t, Y) != MP_EQ) {
if (doLtcTest && mp_cmp(&t, Y) != MP_EQ) {
printf("mp_exptmod test fail!\n");
mp_dump("Y", Y, 0);
@ -645,11 +601,6 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y)
mp_clear(&t);
#endif
#ifndef USE_FAST_MATH
if (szA > LTC_MAX_INT_BYTES)
mp_clear(&tmp);
#endif
return res;
}
@ -668,7 +619,8 @@ int mp_prime_is_prime_ex(mp_int* a, int t, int* result, WC_RNG* rng)
#ifdef ENABLE_NXPLTC_TESTS
int result_soft = 0;
res = mp_prime_is_prime_ex(a, t, &result_soft, rng);
if (doLtcTest)
mp_prime_is_prime_ex(a, t, &result_soft, rng);
#endif
szA = mp_unsigned_bin_size(a);
@ -726,7 +678,7 @@ int mp_prime_is_prime_ex(mp_int* a, int t, int* result, WC_RNG* rng)
#ifdef ENABLE_NXPLTC_TESTS
/* compare hardware vs software */
if (*result != result_soft) {
if (doLtcTest && *result != result_soft) {
printf("Fail! mp_prime_is_prime_ex %d != %d\n", *result, result_soft);
}
#endif
@ -1164,7 +1116,7 @@ status_t LTC_PKHA_Prime25519SquareRootMod(const uint8_t *A, size_t sizeA,
/* I = I - 1 */
XMEMSET(VV, 0xff, sizeof(VV)); /* just temp for maximum integer - for non-modular subtract */
if (0 <= LTC_PKHA_CompareBigNum(I, szI, &one, sizeof(one))) {
if (LTC_PKHA_CompareBigNum(I, szI, &one, sizeof(one)) >= 0) {
if (status == kStatus_Success) {
status = LTC_PKHA_ModSub1(LTC_BASE, I, szI, &one, sizeof(one),
VV, sizeof(VV), I, &szI);