fix for lenght of PKCS8 with ECC and for ECC get key algo ID

This commit is contained in:
Jacob Barthelmeh
2017-03-28 10:36:33 -06:00
parent 72d11e19cd
commit 219fb584e2

View File

@@ -2012,7 +2012,7 @@ int wc_CreatePKCS8Key(byte* out, word32* outSz, byte* key, word32 keySz,
sz = SetLength(oidSz, out + keyIdx); sz = SetLength(oidSz, out + keyIdx);
keyIdx += sz; tmpSz += sz; keyIdx += sz; tmpSz += sz;
XMEMCPY(out + keyIdx, curveOID, oidSz); XMEMCPY(out + keyIdx, curveOID, oidSz);
keyIdx += oidSz; tmpSz += keyIdx; keyIdx += oidSz; tmpSz += oidSz;
} }
out[keyIdx] = ASN_OCTET_STRING; out[keyIdx] = ASN_OCTET_STRING;
@@ -2042,10 +2042,16 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz,
#ifdef HAVE_ECC #ifdef HAVE_ECC
ecc_key ecc; ecc_key ecc;
#endif #endif
#ifndef NO_RSA #ifndef NO_RSA
RsaKey rsa; RsaKey rsa;
#endif
if (algoID == NULL) {
return BAD_FUNC_ARG;
}
*algoID = 0;
#ifndef NO_RSA
wc_InitRsaKey(&rsa, heap); wc_InitRsaKey(&rsa, heap);
if (wc_RsaPrivateKeyDecode(key, &tmpIdx, &rsa, keySz) == 0) { if (wc_RsaPrivateKeyDecode(key, &tmpIdx, &rsa, keySz) == 0) {
*algoID = RSAk; *algoID = RSAk;
@@ -2056,12 +2062,19 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz,
wc_FreeRsaKey(&rsa); wc_FreeRsaKey(&rsa);
#endif /* NO_RSA */ #endif /* NO_RSA */
#ifdef HAVE_ECC #ifdef HAVE_ECC
if (algoID == 0) { if (*algoID != RSAk) {
tmpIdx = 0; tmpIdx = 0;
wc_ecc_init_ex(&ecc, heap, INVALID_DEVID); wc_ecc_init_ex(&ecc, heap, INVALID_DEVID);
if (wc_EccPrivateKeyDecode(key, &tmpIdx, &ecc, keySz) == 0) { if (wc_EccPrivateKeyDecode(key, &tmpIdx, &ecc, keySz) == 0) {
*algoID = ECDSAk; *algoID = ECDSAk;
/* sanity check on arguments */
if (curveOID == NULL || oidSz == NULL) {
WOLFSSL_MSG("Error getting ECC curve OID");
wc_ecc_free(&ecc);
return BAD_FUNC_ARG;
}
/* now find oid */ /* now find oid */
if (wc_ecc_get_oid(ecc.dp->oidSum, curveOID, oidSz) < 0) { if (wc_ecc_get_oid(ecc.dp->oidSum, curveOID, oidSz) < 0) {
WOLFSSL_MSG("Error getting ECC curve OID"); WOLFSSL_MSG("Error getting ECC curve OID");