Fixes for SRP with heap hint.

This commit is contained in:
David Garske
2021-08-17 10:45:50 -07:00
parent c598688f89
commit e1f603301b
3 changed files with 31 additions and 22 deletions

View File

@@ -42,35 +42,35 @@
/** Computes the session key using the Mask Generation Function 1. */ /** Computes the session key using the Mask Generation Function 1. */
static int wc_SrpSetKey(Srp* srp, byte* secret, word32 size); static int wc_SrpSetKey(Srp* srp, byte* secret, word32 size);
static int SrpHashInit(SrpHash* hash, SrpType type) static int SrpHashInit(SrpHash* hash, SrpType type, void* heap)
{ {
hash->type = type; hash->type = type;
switch (type) { switch (type) {
case SRP_TYPE_SHA: case SRP_TYPE_SHA:
#ifndef NO_SHA #ifndef NO_SHA
return wc_InitSha(&hash->data.sha); return wc_InitSha_ex(&hash->data.sha, heap, INVALID_DEVID);
#else #else
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
#endif #endif
case SRP_TYPE_SHA256: case SRP_TYPE_SHA256:
#ifndef NO_SHA256 #ifndef NO_SHA256
return wc_InitSha256(&hash->data.sha256); return wc_InitSha256_ex(&hash->data.sha256, heap, INVALID_DEVID);
#else #else
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
#endif #endif
case SRP_TYPE_SHA384: case SRP_TYPE_SHA384:
#ifdef WOLFSSL_SHA384 #ifdef WOLFSSL_SHA384
return wc_InitSha384(&hash->data.sha384); return wc_InitSha384_ex(&hash->data.sha384, heap, INVALID_DEVID);
#else #else
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
#endif #endif
case SRP_TYPE_SHA512: case SRP_TYPE_SHA512:
#ifdef WOLFSSL_SHA512 #ifdef WOLFSSL_SHA512
return wc_InitSha512(&hash->data.sha512); return wc_InitSha512_ex(&hash->data.sha512, heap, INVALID_DEVID);
#else #else
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
#endif #endif
@@ -217,7 +217,7 @@ static void SrpHashFree(SrpHash* hash)
} }
int wc_SrpInit(Srp* srp, SrpType type, SrpSide side) int wc_SrpInit_ex(Srp* srp, SrpType type, SrpSide side, void* heap, int devId)
{ {
int r; int r;
@@ -265,10 +265,10 @@ int wc_SrpInit(Srp* srp, SrpType type, SrpSide side)
/* initializing variables */ /* initializing variables */
XMEMSET(srp, 0, sizeof(Srp)); XMEMSET(srp, 0, sizeof(Srp));
if ((r = SrpHashInit(&srp->client_proof, type)) != 0) if ((r = SrpHashInit(&srp->client_proof, type, srp->heap)) != 0)
return r; return r;
if ((r = SrpHashInit(&srp->server_proof, type)) != 0) { if ((r = SrpHashInit(&srp->server_proof, type, srp->heap)) != 0) {
SrpHashFree(&srp->client_proof); SrpHashFree(&srp->client_proof);
return r; return r;
} }
@@ -288,14 +288,21 @@ int wc_SrpInit(Srp* srp, SrpType type, SrpSide side)
/* default heap hint to NULL or test value */ /* default heap hint to NULL or test value */
#ifdef WOLFSSL_HEAP_TEST #ifdef WOLFSSL_HEAP_TEST
srp->heap = (void*)WOLFSSL_HEAP_TEST; srp->heap = (void)WOLFSSL_HEAP_TEST;
#else #else
srp->heap = NULL; srp->heap = heap;
#endif #endif /* WOLFSSL_HEAP_TEST */
(void)devId; /* future */
return 0; return 0;
} }
int wc_SrpInit(Srp* srp, SrpType type, SrpSide side)
{
return wc_SrpInit_ex(srp, type, side, NULL, INVALID_DEVID);
}
void wc_SrpTerm(Srp* srp) void wc_SrpTerm(Srp* srp)
{ {
if (srp) { if (srp) {
@@ -382,7 +389,7 @@ int wc_SrpSetParams(Srp* srp, const byte* N, word32 nSz,
srp->saltSz = saltSz; srp->saltSz = saltSz;
/* Set k = H(N, g) */ /* Set k = H(N, g) */
r = SrpHashInit(&hash, srp->type); r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, (byte*) N, nSz); if (!r) r = SrpHashUpdate(&hash, (byte*) N, nSz);
for (i = 0; (word32)i < nSz - gSz; i++) { for (i = 0; (word32)i < nSz - gSz; i++) {
if (!r) r = SrpHashUpdate(&hash, &pad, 1); if (!r) r = SrpHashUpdate(&hash, &pad, 1);
@@ -394,13 +401,13 @@ int wc_SrpSetParams(Srp* srp, const byte* N, word32 nSz,
/* update client proof */ /* update client proof */
/* digest1 = H(N) */ /* digest1 = H(N) */
if (!r) r = SrpHashInit(&hash, srp->type); if (!r) r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, (byte*) N, nSz); if (!r) r = SrpHashUpdate(&hash, (byte*) N, nSz);
if (!r) r = SrpHashFinal(&hash, digest1); if (!r) r = SrpHashFinal(&hash, digest1);
SrpHashFree(&hash); SrpHashFree(&hash);
/* digest2 = H(g) */ /* digest2 = H(g) */
if (!r) r = SrpHashInit(&hash, srp->type); if (!r) r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, (byte*) g, gSz); if (!r) r = SrpHashUpdate(&hash, (byte*) g, gSz);
if (!r) r = SrpHashFinal(&hash, digest2); if (!r) r = SrpHashFinal(&hash, digest2);
SrpHashFree(&hash); SrpHashFree(&hash);
@@ -412,7 +419,7 @@ int wc_SrpSetParams(Srp* srp, const byte* N, word32 nSz,
} }
/* digest2 = H(user) */ /* digest2 = H(user) */
if (!r) r = SrpHashInit(&hash, srp->type); if (!r) r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, srp->user, srp->userSz); if (!r) r = SrpHashUpdate(&hash, srp->user, srp->userSz);
if (!r) r = SrpHashFinal(&hash, digest2); if (!r) r = SrpHashFinal(&hash, digest2);
SrpHashFree(&hash); SrpHashFree(&hash);
@@ -441,7 +448,7 @@ int wc_SrpSetPassword(Srp* srp, const byte* password, word32 size)
digestSz = SrpHashSize(srp->type); digestSz = SrpHashSize(srp->type);
/* digest = H(username | ':' | password) */ /* digest = H(username | ':' | password) */
r = SrpHashInit(&hash, srp->type); r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, srp->user, srp->userSz); if (!r) r = SrpHashUpdate(&hash, srp->user, srp->userSz);
if (!r) r = SrpHashUpdate(&hash, (const byte*) ":", 1); if (!r) r = SrpHashUpdate(&hash, (const byte*) ":", 1);
if (!r) r = SrpHashUpdate(&hash, password, size); if (!r) r = SrpHashUpdate(&hash, password, size);
@@ -449,7 +456,7 @@ int wc_SrpSetPassword(Srp* srp, const byte* password, word32 size)
SrpHashFree(&hash); SrpHashFree(&hash);
/* digest = H(salt | H(username | ':' | password)) */ /* digest = H(salt | H(username | ':' | password)) */
if (!r) r = SrpHashInit(&hash, srp->type); if (!r) r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, srp->salt, srp->saltSz); if (!r) r = SrpHashUpdate(&hash, srp->salt, srp->saltSz);
if (!r) r = SrpHashUpdate(&hash, digest, digestSz); if (!r) r = SrpHashUpdate(&hash, digest, digestSz);
if (!r) r = SrpHashFinal(&hash, digest); if (!r) r = SrpHashFinal(&hash, digest);
@@ -524,7 +531,7 @@ int wc_SrpSetPrivate(Srp* srp, const byte* priv, word32 size)
static int wc_SrpGenPrivate(Srp* srp, byte* priv, word32 size) static int wc_SrpGenPrivate(Srp* srp, byte* priv, word32 size)
{ {
WC_RNG rng; WC_RNG rng;
int r = wc_InitRng(&rng); int r = wc_InitRng_ex(&rng, srp->heap, INVALID_DEVID);
if (!r) r = wc_RNG_GenerateBlock(&rng, priv, size); if (!r) r = wc_RNG_GenerateBlock(&rng, priv, size);
if (!r) r = wc_SrpSetPrivate(srp, priv, size); if (!r) r = wc_SrpSetPrivate(srp, priv, size);
@@ -608,7 +615,7 @@ static int wc_SrpSetKey(Srp* srp, byte* secret, word32 size)
counter[2] = (i >> 8) & 0xFF; counter[2] = (i >> 8) & 0xFF;
counter[3] = i & 0xFF; counter[3] = i & 0xFF;
r = SrpHashInit(&hash, srp->type); r = SrpHashInit(&hash, srp->type, srp->heap);
if (!r) r = SrpHashUpdate(&hash, secret, size); if (!r) r = SrpHashUpdate(&hash, secret, size);
if (!r) r = SrpHashUpdate(&hash, counter, 4); if (!r) r = SrpHashUpdate(&hash, counter, 4);
@@ -688,7 +695,7 @@ int wc_SrpComputeKey(Srp* srp, byte* clientPubKey, word32 clientPubKeySz,
/* initializing variables */ /* initializing variables */
if ((r = SrpHashInit(hash, srp->type)) != 0) if ((r = SrpHashInit(hash, srp->type, srp->heap)) != 0)
goto out; goto out;
digestSz = SrpHashSize(srp->type); digestSz = SrpHashSize(srp->type);

View File

@@ -17262,7 +17262,7 @@ static int srp_test_digest(int dgstType)
/* client knows username and password. */ /* client knows username and password. */
/* server knows N, g, salt and verifier. */ /* server knows N, g, salt and verifier. */
if (!r) r = wc_SrpInit(cli, dgstType, SRP_CLIENT_SIDE); if (!r) r = wc_SrpInit_ex(cli, dgstType, SRP_CLIENT_SIDE, HEAP_HINT, devId);
if (!r) r = wc_SrpSetUsername(cli, username, usernameSz); if (!r) r = wc_SrpSetUsername(cli, username, usernameSz);
/* loading N, g and salt in advance to generate the verifier. */ /* loading N, g and salt in advance to generate the verifier. */
@@ -17275,7 +17275,7 @@ static int srp_test_digest(int dgstType)
/* client sends username to server */ /* client sends username to server */
if (!r) r = wc_SrpInit(srv, dgstType, SRP_SERVER_SIDE); if (!r) r = wc_SrpInit_ex(srv, dgstType, SRP_SERVER_SIDE, HEAP_HINT, devId);
if (!r) r = wc_SrpSetUsername(srv, username, usernameSz); if (!r) r = wc_SrpSetUsername(srv, username, usernameSz);
if (!r) r = wc_SrpSetParams(srv, N, sizeof(N), if (!r) r = wc_SrpSetParams(srv, N, sizeof(N),
g, sizeof(g), g, sizeof(g),

View File

@@ -137,6 +137,8 @@ typedef struct Srp {
* @return 0 on success, {@literal <} 0 on error. @see error-crypt.h * @return 0 on success, {@literal <} 0 on error. @see error-crypt.h
*/ */
WOLFSSL_API int wc_SrpInit(Srp* srp, SrpType type, SrpSide side); WOLFSSL_API int wc_SrpInit(Srp* srp, SrpType type, SrpSide side);
WOLFSSL_API int wc_SrpInit_ex(Srp* srp, SrpType type, SrpSide side,
void* heap, int devId);
/** /**
* Releases the Srp struct resources after usage. * Releases the Srp struct resources after usage.