diff --git a/wolfcrypt/src/asn.c b/wolfcrypt/src/asn.c index 8b92f6007..00dfbb702 100644 --- a/wolfcrypt/src/asn.c +++ b/wolfcrypt/src/asn.c @@ -3774,43 +3774,75 @@ static int GetAlgoV2(int encAlgId, const byte** oid, int *len, int* id, return ret; } -/* Converts Encrypted PKCS#8 to 'traditional' (i.e. PKCS#8 removed from - * decrypted key.) +/* PKCS#8 encryption from RFC 5208 + * This function takes in an unencrypted PKCS#8 DER key and converts it to + * PKCS#8 encrypted format. The resulting encrypted key can be decrypted using + * wc_DecryptPKCS8Key. + * + * EncryptedPrivateKeyInfo ::= SEQUENCE { + * encryptionAlgorithm EncryptionAlgorithmIdentifier, + * encryptedData EncryptedData } + * EncryptionAlgorithmIdentifier ::= AlgorithmIdentifier + * EncryptedData ::= OCTET STRING + * + * key DER buffer containing the unencrypted PKCS#8 key. + * keySz The size of the key buffer. + * out The buffer to place the encrypted key in. + * outSz The size of the out buffer. + * password The password to use for the password-based encryption algorithm. + * passwordSz The length of the password (not including the NULL terminator). + * vPKCS The PKCS version to use. Can be 1 for PKCS12 or PKCS5. + * pbeOid The OID of the PBE scheme to use (e.g. PBES2 or one of the OIDs + for PBES1 in RFC 2898 A.3) + * encAlgId The encryption algorithm ID to use (e.g. AES256CBCb). + * salt The salt buffer to use. If NULL, a random salt will be used. + * saltSz The length of the salt buffer. Can be 0 if passing NULL for salt. + * itt The number of iterations to use for the KDF. + * rng A pointer to an initialized WC_RNG object. + * heap A pointer to the heap use for dynamic allocation. Can be NULL. + * + * Returns the size of the encrypted key placed in out. In error cases, returns + * negative values. */ -int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, - const char* password, int passwordSz, int vPKCS, int vAlgo, +int wc_EncryptPKCS8Key(byte* key, word32 keySz, byte* out, word32 outSz, + const char* password, int passwordSz, int vPKCS, int pbeOid, int encAlgId, byte* salt, word32 saltSz, int itt, WC_RNG* rng, void* heap) { - int ret = 0; - int version, blockSz, id; - word32 idx = 0, encIdx; #ifdef WOLFSSL_SMALL_STACK byte* saltTmp = NULL; #else byte saltTmp[MAX_SALT_SIZE]; #endif - byte cbcIv[MAX_IV_SIZE]; - byte *pkcs8Key = NULL; - word32 pkcs8KeySz = 0, padSz = 0; - int algId = 0; - const byte* curveOid = NULL; - word32 curveOidSz = 0; - const byte* pbeOid = NULL; - word32 pbeOidSz = 0; + int ret = 0; + int version = 0; + int pbeId = 0; + int blockSz = 0; const byte* encOid = NULL; int encOidSz = 0; - word32 pbeLen = 0, kdfLen = 0, encLen = 0; - word32 innerLen = 0, outerLen; + word32 padSz = 0; + word32 innerLen = 0; + word32 outerLen = 0; + const byte* pbeOidBuf = NULL; + word32 pbeOidBufSz = 0; + word32 pbeLen = 0; + word32 kdfLen = 0; + word32 encLen = 0; + byte cbcIv[MAX_IV_SIZE]; + word32 idx = 0; + word32 encIdx = 0; - ret = CheckAlgo(vPKCS, vAlgo, &id, &version, &blockSz); - /* create random salt if one not provided */ + (void)heap; + + WOLFSSL_ENTER("wc_EncryptPKCS8Key"); + + ret = CheckAlgo(vPKCS, pbeOid, &pbeId, &version, &blockSz); if (ret == 0 && (salt == NULL || saltSz == 0)) { saltSz = 8; #ifdef WOLFSSL_SMALL_STACK saltTmp = (byte*)XMALLOC(saltSz, heap, DYNAMIC_TYPE_TMP_BUFFER); if (saltTmp == NULL) - return MEMORY_E; + ret = MEMORY_E; #endif salt = saltTmp; @@ -3819,52 +3851,25 @@ int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, #ifdef WOLFSSL_SMALL_STACK XFREE(saltTmp, heap, DYNAMIC_TYPE_TMP_BUFFER); #endif - return ret; - } - } - - if (ret == 0) { - /* check key type and get OID if ECC */ - ret = wc_GetKeyOID(key, keySz, &curveOid, &curveOidSz, &algId, heap); - if (ret == 1) - ret = 0; - } - if (ret == 0) { - ret = wc_CreatePKCS8Key(NULL, &pkcs8KeySz, key, keySz, algId, curveOid, - curveOidSz); - if (ret == LENGTH_ONLY_E) - ret = 0; - } - if (ret == 0) { - pkcs8Key = (byte*)XMALLOC(pkcs8KeySz, NULL, DYNAMIC_TYPE_TMP_BUFFER); - if (pkcs8Key == NULL) - ret = MEMORY_E; - } - if (ret == 0) { - ret = wc_CreatePKCS8Key(pkcs8Key, &pkcs8KeySz, key, keySz, algId, - curveOid, curveOidSz); - if (ret >= 0) { - pkcs8KeySz = ret; - ret = 0; } } if (ret == 0 && version == PKCS5v2) - ret = GetAlgoV2(encAlgId, &encOid, &encOidSz, &id, &blockSz); + ret = GetAlgoV2(encAlgId, &encOid, &encOidSz, &pbeId, &blockSz); if (ret == 0) { - padSz = (blockSz - (pkcs8KeySz & (blockSz - 1))) & (blockSz - 1); + padSz = (blockSz - (keySz & (blockSz - 1))) & (blockSz - 1); /* inner = OCT salt INT itt */ innerLen = 2 + saltSz + 2 + (itt < 256 ? 1 : 2); if (version != PKCS5v2) { - pbeOid = OidFromId(id, oidPBEType, &pbeOidSz); + pbeOidBuf = OidFromId(pbeId, oidPBEType, &pbeOidBufSz); /* pbe = OBJ pbse1 SEQ [ inner ] */ - pbeLen = 2 + pbeOidSz + 2 + innerLen; + pbeLen = 2 + pbeOidBufSz + 2 + innerLen; } else { - pbeOid = pbes2; - pbeOidSz = sizeof(pbes2); + pbeOidBuf = pbes2; + pbeOidBufSz = sizeof(pbes2); /* kdf = OBJ pbkdf2 [ SEQ innerLen ] */ kdfLen = 2 + sizeof(pbkdf2Oid) + 2 + innerLen; /* enc = OBJ enc_alg OCT iv */ @@ -3878,35 +3883,35 @@ int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, if (ret == 0) { /* outer = SEQ [ pbe ] OCT encrypted_PKCS#8_key */ outerLen = 2 + pbeLen; - outerLen += SetOctetString(pkcs8KeySz + padSz, out); - outerLen += pkcs8KeySz + padSz; + outerLen += SetOctetString(keySz + padSz, out); + outerLen += keySz + padSz; idx += SetSequence(outerLen, out + idx); - encIdx = idx + outerLen - pkcs8KeySz - padSz; + encIdx = idx + outerLen - keySz - padSz; /* Put Encrypted content in place. */ - XMEMCPY(out + encIdx, pkcs8Key, pkcs8KeySz); + XMEMCPY(out + encIdx, key, keySz); if (padSz > 0) { - XMEMSET(out + encIdx + pkcs8KeySz, padSz, padSz); - pkcs8KeySz += padSz; + XMEMSET(out + encIdx + keySz, padSz, padSz); + keySz += padSz; } - ret = wc_CryptKey(password, passwordSz, salt, saltSz, itt, id, - out + encIdx, pkcs8KeySz, version, cbcIv, 1, 0); + ret = wc_CryptKey(password, passwordSz, salt, saltSz, itt, pbeId, + out + encIdx, keySz, version, cbcIv, 1, 0); } if (ret == 0) { if (version != PKCS5v2) { /* PBE algorithm */ idx += SetSequence(pbeLen, out + idx); - idx += SetObjectId(pbeOidSz, out + idx); - XMEMCPY(out + idx, pbeOid, pbeOidSz); - idx += pbeOidSz; + idx += SetObjectId(pbeOidBufSz, out + idx); + XMEMCPY(out + idx, pbeOidBuf, pbeOidBufSz); + idx += pbeOidBufSz; } else { /* PBES2 algorithm identifier */ idx += SetSequence(pbeLen, out + idx); - idx += SetObjectId(pbeOidSz, out + idx); - XMEMCPY(out + idx, pbeOid, pbeOidSz); - idx += pbeOidSz; + idx += SetObjectId(pbeOidBufSz, out + idx); + XMEMCPY(out + idx, pbeOidBuf, pbeOidBufSz); + idx += pbeOidBufSz; /* PBES2 Parameters: SEQ [ kdf ] SEQ [ enc ] */ idx += SetSequence(2 + kdfLen + 2 + encLen, out + idx); /* KDF Algorithm Identifier */ @@ -3918,7 +3923,7 @@ int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, idx += SetSequence(innerLen, out + idx); idx += SetOctetString(saltSz, out + idx); XMEMCPY(out + idx, salt, saltSz); idx += saltSz; - ret = SetShortInt(out, &idx, itt, *outSz); + ret = SetShortInt(out, &idx, itt, outSz); if (ret > 0) ret = 0; } @@ -3934,23 +3939,115 @@ int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, XMEMCPY(out + idx, cbcIv, blockSz); idx += blockSz; } - idx += SetOctetString(pkcs8KeySz, out + idx); + idx += SetOctetString(keySz, out + idx); /* Default PRF - no need to write out OID */ - idx += pkcs8KeySz; + idx += keySz; ret = idx; } - if (pkcs8Key != NULL) { - ForceZero(pkcs8Key, pkcs8KeySz); - XFREE(pkcs8Key, NULL, DYNAMIC_TYPE_TMP_BUFFER); - } #ifdef WOLFSSL_SMALL_STACK if (saltTmp != NULL) { XFREE(saltTmp, heap, DYNAMIC_TYPE_TMP_BUFFER); } #endif + WOLFSSL_LEAVE("wc_EncryptPKCS8Key", ret); + + return ret; +} + +/* PKCS#8 decryption from RFC 5208 + * + * NOTE: input buffer is overwritten with decrypted data! + * + * This function takes an encrypted PKCS#8 DER key and decrypts it to PKCS#8 + * unencrypted DER. Undoes the encryption done by wc_EncryptPKCS8Key. Returns + * the length of the decrypted buffer or a negative value if there was an error. + */ +int wc_DecryptPKCS8Key(byte* input, word32 sz, const char* password, + int passwordSz) +{ + int ret; + int length; + word32 inOutIdx = 0; + + if (GetSequence(input, &inOutIdx, &length, sz) < 0) { + ret = ASN_PARSE_E; + } + else { + ret = DecryptContent(input + inOutIdx, sz - inOutIdx, password, + passwordSz); + if (ret > 0) { + XMEMMOVE(input, input + inOutIdx, ret); + } + } + + if (ret > 0) { + /* DecryptContent will decrypt the data, but it will leave any padding + * bytes intact. This code calculates the length without the padding + * and we return that to the user. */ + inOutIdx = 0; + if (GetSequence(input, &inOutIdx, &length, ret) < 0) { + ret = ASN_PARSE_E; + } + else { + ret = inOutIdx + length; + } + } + + return ret; +} + +/* Takes an unencrypted, traditional DER-encoded key and converts it to a PKCS#8 + * encrypted key. */ +int TraditionalEnc(byte* key, word32 keySz, byte* out, word32* outSz, + const char* password, int passwordSz, int vPKCS, int vAlgo, + int encAlgId, byte* salt, word32 saltSz, int itt, WC_RNG* rng, + void* heap) +{ + int ret = 0; + byte *pkcs8Key = NULL; + word32 pkcs8KeySz = 0; + int algId = 0; + const byte* curveOid = NULL; + word32 curveOidSz = 0; + + if (ret == 0) { + /* check key type and get OID if ECC */ + ret = wc_GetKeyOID(key, keySz, &curveOid, &curveOidSz, &algId, heap); + if (ret == 1) + ret = 0; + } + if (ret == 0) { + ret = wc_CreatePKCS8Key(NULL, &pkcs8KeySz, key, keySz, algId, curveOid, + curveOidSz); + if (ret == LENGTH_ONLY_E) + ret = 0; + } + if (ret == 0) { + pkcs8Key = (byte*)XMALLOC(pkcs8KeySz, heap, DYNAMIC_TYPE_TMP_BUFFER); + if (pkcs8Key == NULL) + ret = MEMORY_E; + } + if (ret == 0) { + ret = wc_CreatePKCS8Key(pkcs8Key, &pkcs8KeySz, key, keySz, algId, + curveOid, curveOidSz); + if (ret >= 0) { + pkcs8KeySz = ret; + ret = 0; + } + } + if (ret == 0) { + ret = wc_EncryptPKCS8Key(pkcs8Key, pkcs8KeySz, out, *outSz, password, + passwordSz, vPKCS, vAlgo, encAlgId, salt, saltSz, itt, rng, heap); + } + + if (pkcs8Key != NULL) { + ForceZero(pkcs8Key, pkcs8KeySz); + XFREE(pkcs8Key, heap, DYNAMIC_TYPE_TMP_BUFFER); + } + (void)rng; return ret; @@ -4127,25 +4224,16 @@ exit_dc: return ret; } - /* Remove Encrypted PKCS8 header, move beginning of traditional to beginning of input */ -int ToTraditionalEnc(byte* input, word32 sz,const char* password, +int ToTraditionalEnc(byte* input, word32 sz, const char* password, int passwordSz, word32* algId) { - int ret, length; - word32 inOutIdx = 0; + int ret; - if (GetSequence(input, &inOutIdx, &length, sz) < 0) { - ret = ASN_PARSE_E; - } - else { - ret = DecryptContent(input + inOutIdx, sz - inOutIdx, password, - passwordSz); - if (ret > 0) { - XMEMMOVE(input, input + inOutIdx, ret); - ret = ToTraditional_ex(input, ret, algId); - } + ret = wc_DecryptPKCS8Key(input, sz, password, passwordSz); + if (ret > 0) { + ret = ToTraditional_ex(input, ret, algId); } return ret; diff --git a/wolfssl/wolfcrypt/asn_public.h b/wolfssl/wolfcrypt/asn_public.h index ed1775d91..289f3620e 100644 --- a/wolfssl/wolfcrypt/asn_public.h +++ b/wolfssl/wolfcrypt/asn_public.h @@ -583,6 +583,9 @@ WOLFSSL_API int wc_GetPkcs8TraditionalOffset(byte* input, word32* inOutIdx, word32 sz); WOLFSSL_API int wc_CreatePKCS8Key(byte* out, word32* outSz, byte* key, word32 keySz, int algoID, const byte* curveOID, word32 oidSz); +WOLFSSL_API int wc_EncryptPKCS8Key(byte*, word32, byte*, word32, const char*, + int, int, int, int, byte*, word32, int, WC_RNG*, void*); +WOLFSSL_API int wc_DecryptPKCS8Key(byte*, word32, const char*, int); #ifndef NO_ASN_TIME /* Time */