diff --git a/src/ssl.c b/src/ssl.c index 407b01fe7..4b5683f6c 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -20541,6 +20541,9 @@ WOLFSSL_DH *wolfSSL_d2i_DHparams(WOLFSSL_DH **dh, const unsigned char **pp, } #endif /* !HAVE_FIPS || HAVE_FIPS_VERSION > 2 */ +#define ASN_LEN_SIZE(l) \ + (((l) < 128) ? 1 : (((l) < 256) ? 2 : 3)) + /* Converts internal WOLFSSL_DH structure to DER encoded DH. * * dh : structure to copy DH parameters from. @@ -20552,6 +20555,8 @@ int wolfSSL_i2d_DHparams(const WOLFSSL_DH *dh, unsigned char **out) { word32 len; int ret = 0; + int pSz; + int gSz; WOLFSSL_ENTER("wolfSSL_i2d_DHparams"); @@ -20561,15 +20566,17 @@ int wolfSSL_i2d_DHparams(const WOLFSSL_DH *dh, unsigned char **out) } /* Get total length */ - len = 2 + mp_leading_bit((mp_int*)dh->p->internal) + - mp_unsigned_bin_size((mp_int*)dh->p->internal) + - 2 + mp_leading_bit((mp_int*)dh->g->internal) + - mp_unsigned_bin_size((mp_int*)dh->g->internal); + pSz = mp_unsigned_bin_size((mp_int*)dh->p->internal); + gSz = mp_unsigned_bin_size((mp_int*)dh->g->internal); + len = 1 + ASN_LEN_SIZE(pSz) + mp_leading_bit((mp_int*)dh->p->internal) + + pSz + + 1 + ASN_LEN_SIZE(gSz) + mp_leading_bit((mp_int*)dh->g->internal) + + gSz; /* Two bytes required for length if ASN.1 SEQ data greater than 127 bytes * and less than 256 bytes. */ - len = ((len > 127) ? 2 : 1) + len; + len += 1 + ASN_LEN_SIZE(len); if (out != NULL && *out != NULL) { ret = StoreDHparams(*out, &len, (mp_int*)dh->p->internal, diff --git a/tests/api.c b/tests/api.c index 39eec6d39..b28b08a41 100644 --- a/tests/api.c +++ b/tests/api.c @@ -37995,9 +37995,11 @@ static void test_wolfSSL_i2d_DHparams(void) AssertIntEQ(DH_generate_key(dh), 1); AssertIntEQ(wolfSSL_i2d_DHparams(dh, &pt2), 268); - /* Invalid cases */ + /* Invalid case */ AssertIntEQ(wolfSSL_i2d_DHparams(NULL, &pt2), 0); - AssertIntEQ(wolfSSL_i2d_DHparams(dh, NULL), 264); + + /* Return length only */ + AssertIntEQ(wolfSSL_i2d_DHparams(dh, NULL), 268); DH_free(dh); printf(resultFmt, passed); @@ -38019,9 +38021,11 @@ static void test_wolfSSL_i2d_DHparams(void) AssertIntEQ(DH_generate_key(dh), 1); AssertIntEQ(wolfSSL_i2d_DHparams(dh, &pt2), 396); - /* Invalid cases */ + /* Invalid case */ AssertIntEQ(wolfSSL_i2d_DHparams(NULL, &pt2), 0); - AssertIntEQ(wolfSSL_i2d_DHparams(dh, NULL), 392); + + /* Return length only */ + AssertIntEQ(wolfSSL_i2d_DHparams(dh, NULL), 396); DH_free(dh); printf(resultFmt, passed); diff --git a/wolfcrypt/src/asn.c b/wolfcrypt/src/asn.c index d3d096d98..fdc363d9c 100644 --- a/wolfcrypt/src/asn.c +++ b/wolfcrypt/src/asn.c @@ -16093,47 +16093,36 @@ int EncodePolicyOID(byte *out, word32 *outSz, const char *in, void* heap) int StoreDHparams(byte* out, word32* outLen, mp_int* p, mp_int* g) { word32 idx = 0; - int pSz; - int gSz; - unsigned int tmp; - word32 headerSz = 4; /* 2*ASN_TAG + 2*LEN(ENUM) */ - - /* If the leading bit on the INTEGER is a 1, add a leading zero */ - int pLeadingZero = mp_leading_bit(p); - int gLeadingZero = mp_leading_bit(g); - int pLen = mp_unsigned_bin_size(p); - int gLen = mp_unsigned_bin_size(g); + word32 total; WOLFSSL_ENTER("StoreDHparams"); + if (out == NULL) { WOLFSSL_MSG("Null buffer error"); return BUFFER_E; } - tmp = pLeadingZero + gLeadingZero + pLen + gLen; - if (*outLen < (tmp + headerSz)) { + /* determine size */ + /* integer - g */ + idx = SetASNIntMP(g, -1, NULL); + /* integer - p */ + idx += SetASNIntMP(p, -1, NULL); + total = idx; + /* sequence */ + idx += SetSequence(idx, NULL); + + /* make sure output fits in buffer */ + if (idx > *outLen) { return BUFFER_E; } - /* Set sequence */ - idx = SetSequence(tmp + headerSz + 2, out); - - /* Encode p */ - pSz = SetASNIntMP(p, -1, &out[idx]); - if (pSz < 0) { - WOLFSSL_MSG("SetASNIntMP failed"); - return pSz; - } - idx += pSz; - - /* Encode g */ - gSz = SetASNIntMP(g, -1, &out[idx]); - if (gSz < 0) { - WOLFSSL_MSG("SetASNIntMP failed"); - return gSz; - } - idx += gSz; - + /* write DH parameters */ + /* sequence - for P and G only */ + idx = SetSequence(total, out); + /* integer - p */ + idx += SetASNIntMP(p, -1, out + idx); + /* integer - g */ + idx += SetASNIntMP(g, -1, out + idx); *outLen = idx; return 0;