Get PSS going on server side

This commit is contained in:
Sean Parkinson
2017-05-12 11:33:52 +10:00
parent 2f15d57a6f
commit 9fb6373cfb
6 changed files with 225 additions and 77 deletions

View File

@ -2732,8 +2732,41 @@ void FreeX509(WOLFSSL_X509* x509)
#ifndef NO_RSA
#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS)
static int ConvertHashPss(int hashAlgo, enum wc_HashType* hashType, int* mgf) {
switch (hashAlgo) {
#ifdef WOLFSSL_SHA512
case sha512_mac:
*hashType = WC_HASH_TYPE_SHA512;
if (mgf != NULL)
*mgf = WC_MGF1SHA512;
break;
#endif
#ifdef WOLFSSL_SHA384
case sha384_mac:
*hashType = WC_HASH_TYPE_SHA384;
if (mgf != NULL)
*mgf = WC_MGF1SHA384;
break;
#endif
#ifndef NO_SHA256
case sha256_mac:
*hashType = WC_HASH_TYPE_SHA256;
if (mgf != NULL)
*mgf = WC_MGF1SHA256;
break;
#endif
default:
return BAD_FUNC_ARG;
}
return 0;
}
#endif
int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out,
word32* outSz, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx)
word32* outSz, int sigAlgo, int hashAlgo, RsaKey* key,
const byte* keyBuf, word32 keySz, void* ctx)
{
int ret;
@ -2741,6 +2774,8 @@ int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out,
(void)keyBuf;
(void)keySz;
(void)ctx;
(void)sigAlgo;
(void)hashAlgo;
WOLFSSL_ENTER("RsaSign");
@ -2752,6 +2787,19 @@ int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out,
else
#endif /*HAVE_PK_CALLBACKS */
{
#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS)
if (sigAlgo == rsa_pss_sa_algo) {
enum wc_HashType hashType = WC_HASH_TYPE_NONE;
int mgf = 0;
ret = ConvertHashPss(hashAlgo, &hashType, &mgf);
if (ret != 0)
return ret;
ret = wc_RsaPSS_Sign(in, inSz, out, *outSz, hashType, mgf, key,
ssl->rng);
}
else
#endif
ret = wc_RsaSSL_Sign(in, inSz, out, *outSz, key, ssl->rng);
}
@ -2795,35 +2843,17 @@ int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo,
else
#endif /*HAVE_PK_CALLBACKS */
{
#ifdef WOLFSSL_TLS13
#ifdef WC_RSA_PSS
#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS)
if (sigAlgo == rsa_pss_sa_algo) {
enum wc_HashType hashType = WC_HASH_TYPE_NONE;
int mgf = 0;
switch (hashAlgo) {
case sha512_mac:
#ifdef WOLFSSL_SHA512
hashType = WC_HASH_TYPE_SHA512;
mgf = WC_MGF1SHA512;
#endif
break;
case sha384_mac:
#ifdef WOLFSSL_SHA384
hashType = WC_HASH_TYPE_SHA384;
mgf = WC_MGF1SHA384;
#endif
break;
case sha256_mac:
#ifndef NO_SHA256
hashType = WC_HASH_TYPE_SHA256;
mgf = WC_MGF1SHA256;
#endif
break;
}
ret = ConvertHashPss(hashAlgo, &hashType, &mgf);
if (ret != 0)
return ret;
ret = wc_RsaPSS_VerifyInline(in, inSz, out, hashType, mgf, key);
}
else
#endif
#endif
ret = wc_RsaSSL_VerifyInline(in, inSz, out, key);
}
@ -2840,14 +2870,40 @@ int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo,
return ret;
}
#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS)
int CheckRsaPssPadding(const byte* plain, word32 plainSz, byte* out,
word32 sigSz, enum wc_HashType hashType)
{
int ret;
if (plainSz != sigSz || out == NULL)
ret = VERIFY_CERT_ERROR;
else {
out -= 2 * sigSz;
XMEMCPY(out, plain, plainSz);
out -= 8;
XMEMSET(out, 0, 8);
wc_Hash(hashType, out, 8 + plainSz * 2, out, plainSz);
if (XMEMCMP(out, out + 8 + plainSz * 2, plainSz) != 0)
ret = VERIFY_CERT_ERROR;
else
ret = 0;
}
return ret;
}
#endif
/* Verify RSA signature, 0 on success */
int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz,
const byte* plain, word32 plainSz, RsaKey* key)
const byte* plain, word32 plainSz, int sigAlgo, int hashAlgo, RsaKey* key)
{
byte* out = NULL; /* inline result */
int ret;
(void)ssl;
(void)sigAlgo;
(void)hashAlgo;
WOLFSSL_ENTER("VerifyRsaSign");
@ -2860,8 +2916,23 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz,
return BUFFER_E;
}
ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key);
#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS)
if (sigAlgo == rsa_pss_sa_algo) {
enum wc_HashType hashType = WC_HASH_TYPE_NONE;
int mgf = 0;
ret = ConvertHashPss(hashAlgo, &hashType, &mgf);
if (ret != 0)
return ret;
ret = wc_RsaPSS_VerifyInline(verifySig, sigSz, &out, hashType, mgf,
key);
if (ret > 0)
ret = CheckRsaPssPadding(plain, plainSz, out, ret, hashType);
}
else
#endif
{
ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key);
if (ret > 0) {
if (ret != (int)plainSz || !out ||
XMEMCMP(plain, out, plainSz) != 0) {
@ -2871,6 +2942,7 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz,
ret = 0; /* RSA reset */
}
}
}
/* Handle async pending response */
#if defined(WOLFSSL_ASYNC_CRYPT)
@ -18216,7 +18288,7 @@ int SendCertificateVerify(WOLFSSL* ssl)
ret = RsaSign(ssl,
ssl->buffers.sig.buffer, ssl->buffers.sig.length,
args->verify + args->extraSz + VERIFY_HEADER, &args->sigSz,
key,
rsa_sa_algo, no_mac, key,
ssl->buffers.key->buffer,
ssl->buffers.key->length,
#ifdef HAVE_PK_CALLBACKS
@ -18271,7 +18343,7 @@ int SendCertificateVerify(WOLFSSL* ssl)
ret = VerifyRsaSign(ssl,
args->verifySig, args->sigSz,
ssl->buffers.sig.buffer, ssl->buffers.sig.length,
key
rsa_sa_algo, no_mac, key
);
}
#endif /* !NO_RSA */
@ -19816,7 +19888,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
ssl->buffers.sig.length,
args->output + args->idx,
&args->sigSz,
key,
rsa_sa_algo, no_mac, key,
ssl->buffers.key->buffer,
ssl->buffers.key->length,
#ifdef HAVE_PK_CALLBACKS
@ -19872,7 +19944,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
ssl->buffers.sig.length,
args->output + args->idx,
&args->sigSz,
key,
rsa_sa_algo, no_mac, key,
ssl->buffers.key->buffer,
ssl->buffers.key->length,
#ifdef HAVE_PK_CALLBACKS
@ -19955,7 +20027,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
args->verifySig, args->sigSz,
ssl->buffers.sig.buffer,
ssl->buffers.sig.length,
key
rsa_sa_algo, no_mac, key
);
break;
}
@ -20010,7 +20082,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
args->verifySig, args->sigSz,
ssl->buffers.sig.buffer,
ssl->buffers.sig.length,
key
rsa_sa_algo, no_mac, key
);
break;
}

View File

@ -4566,7 +4566,7 @@ static word16 TLSX_SignatureAlgorithms_Write(byte* data, byte* output)
static int TLSX_SignatureAlgorithms_Parse(WOLFSSL *ssl, byte* input,
word16 length)
{
int ret = 0;
int i;
word16 len;
(void)ssl;
@ -4581,9 +4581,13 @@ static int TLSX_SignatureAlgorithms_Parse(WOLFSSL *ssl, byte* input,
if (length != OPAQUE16_LEN + len)
return BUFFER_ERROR;
/* Ignore for now. */
ssl->pssAlgo = 0;
for (i = 0; i < len; i += 2) {
if (input[i] == 0x08 && input[i + 1] <= 0x06)
ssl->pssAlgo |= 1 << input[i + 1];
}
return ret;
return 0;
}
/* Sets a new SupportedVersions extension into the extension list.

View File

@ -2882,18 +2882,24 @@ static INLINE void EncodeSigAlg(byte hashAlgo, byte hsType, byte* output)
{
switch (hsType) {
#ifdef HAVE_ECC
case DYNAMIC_TYPE_ECC:
case ecc_dsa_sa_algo:
output[0] = hashAlgo;
output[1] = ecc_dsa_sa_algo;
break;
#endif
#ifndef NO_RSA
case DYNAMIC_TYPE_RSA:
case rsa_sa_algo:
output[0] = hashAlgo;
output[1] = rsa_sa_algo;
break;
#endif
#ifdef WC_RSA_PSS
/* PSS signatures: 0x080[4-6] */
case rsa_pss_sa_algo:
output[0] = rsa_pss_sa_algo;
output[1] = hashAlgo;
break;
#endif
#endif
/* ED25519: 0x0807 */
/* ED448: 0x0808 */
}
@ -2908,6 +2914,7 @@ static INLINE void EncodeSigAlg(byte hashAlgo, byte hsType, byte* output)
static INLINE void DecodeSigAlg(byte* input, byte* hashAlgo, byte* hsType)
{
switch (input[0]) {
#ifdef WC_RSA_PSS
case 0x08:
/* PSS signatures: 0x080[4-6] */
if (input[1] <= 0x06) {
@ -2915,6 +2922,7 @@ static INLINE void DecodeSigAlg(byte* input, byte* hashAlgo, byte* hsType)
*hashAlgo = input[1];
}
break;
#endif
/* ED25519: 0x0807 */
/* ED448: 0x0808 */
default:
@ -3014,12 +3022,22 @@ static void CreateSigData(WOLFSSL* ssl, byte* sigData, word16* sigDataSz,
* returns the length of the encoded signature or negative on error.
*/
static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz,
int hashAlgo)
int sigAlgo, int hashAlgo)
{
Digest digest;
int hashSz = 0;
int hashOid = 0;
int ret = BAD_FUNC_ARG;
byte* hash;
(void)sigAlgo;
#ifdef WC_RSA_PSS
if (sigAlgo == rsa_pss_sa_algo)
hash = sig;
else
#endif
hash = sigData;
/* Digest the signature data. */
switch (hashAlgo) {
@ -3029,7 +3047,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz,
if (ret == 0) {
ret = wc_Sha256Update(&digest.sha256, sigData, sigDataSz);
if (ret == 0)
ret = wc_Sha256Final(&digest.sha256, sigData);
ret = wc_Sha256Final(&digest.sha256, hash);
wc_Sha256Free(&digest.sha256);
}
hashSz = SHA256_DIGEST_SIZE;
@ -3042,7 +3060,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz,
if (ret == 0) {
ret = wc_Sha384Update(&digest.sha384, sigData, sigDataSz);
if (ret == 0)
ret = wc_Sha384Final(&digest.sha384, sigData);
ret = wc_Sha384Final(&digest.sha384, hash);
wc_Sha384Free(&digest.sha384);
}
hashSz = SHA384_DIGEST_SIZE;
@ -3055,7 +3073,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz,
if (ret == 0) {
ret = wc_Sha512Update(&digest.sha512, sigData, sigDataSz);
if (ret == 0)
ret = wc_Sha512Final(&digest.sha512, sigData);
ret = wc_Sha512Final(&digest.sha512, hash);
wc_Sha512Free(&digest.sha512);
}
hashSz = SHA512_DIGEST_SIZE;
@ -3067,8 +3085,15 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz,
if (ret != 0)
return ret;
#ifdef WC_RSA_PSS
if (sigAlgo == rsa_pss_sa_algo)
return hashSz;
else
#endif
{
/* Encode the signature data as per PKCS #1.5 */
return wc_EncodeSignature(sig, sigData, hashSz, hashOid);
return wc_EncodeSignature(sig, hash, hashSz, hashOid);
}
}
#ifdef HAVE_ECC
@ -3156,7 +3181,39 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo,
#endif
word32 sigSz;
if (sigAlgo == rsa_sa_algo) {
CreateSigData(ssl, sigData, &sigDataSz, 1);
#ifdef WC_RSA_PSS
if (sigAlgo == rsa_pss_sa_algo) {
int hashType = WC_HASH_TYPE_NONE;
switch (hashAlgo) {
case sha512_mac:
#ifdef WOLFSSL_SHA512
hashType = WC_HASH_TYPE_SHA512;
#endif
break;
case sha384_mac:
#ifdef WOLFSSL_SHA384
hashType = WC_HASH_TYPE_SHA384;
#endif
break;
case sha256_mac:
#ifndef NO_SHA256
hashType = WC_HASH_TYPE_SHA256;
#endif
break;
}
ret = sigSz = CreateRSAEncodedSig(sigData, sigData, sigDataSz,
rsa_pss_sa_algo, hashAlgo);
if (ret < 0)
return ret;
ret = CheckRsaPssPadding(sigData, sigSz, decSig, decSigSz, hashType);
}
else
#endif
{
#ifdef WOLFSSL_SMALL_STACK
encodedSig = (byte*)XMALLOC(MAX_ENCODED_SIG_SZ, ssl->heap,
DYNAMIC_TYPE_TMP_BUFFER);
@ -3166,29 +3223,14 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo,
}
#endif
CreateSigData(ssl, sigData, &sigDataSz, 1);
sigSz = CreateRSAEncodedSig(encodedSig, sigData, sigDataSz, hashAlgo);
sigSz = CreateRSAEncodedSig(encodedSig, sigData, sigDataSz,
DYNAMIC_TYPE_RSA, hashAlgo);
/* Check the encoded and decrypted signature data match. */
if (decSigSz != sigSz || decSig == NULL ||
XMEMCMP(decSig, encodedSig, sigSz) != 0) {
ret = VERIFY_CERT_ERROR;
}
}
else {
CreateSigData(ssl, sigData, &sigDataSz, 1);
sigSz = CreateECCEncodedSig(sigData, sigDataSz, hashAlgo);
if (decSigSz != sigSz || decSig == NULL)
ret = VERIFY_CERT_ERROR;
else {
decSig -= 2 * decSigSz;
XMEMCPY(decSig, sigData, decSigSz);
decSig -= 8;
XMEMSET(decSig, 0, 8);
CreateECCEncodedSig(decSig, 8 + decSigSz * 2, hashAlgo);
if (XMEMCMP(decSig, decSig + 8 + decSigSz * 2, decSigSz) != 0)
ret = VERIFY_CERT_ERROR;
}
}
#ifdef WOLFSSL_SMALL_STACK
end:
@ -3465,6 +3507,7 @@ typedef struct Scv13Args {
int sendSz;
word16 length;
int sigAlgo;
byte* sigData;
word16 sigDataSz;
} Scv13Args;
@ -3570,7 +3613,17 @@ int SendTls13CertificateVerify(WOLFSSL* ssl)
goto exit_scv;
/* Add signature algorithm. */
EncodeSigAlg(ssl->suites->hashAlgo, ssl->hsType, args->verify);
if (ssl->hsType == DYNAMIC_TYPE_RSA) {
#ifdef WC_RSA_PSS
if (ssl->pssAlgo | (1 << ssl->suites->hashAlgo))
args->sigAlgo = rsa_pss_sa_algo;
else
#endif
args->sigAlgo = rsa_sa_algo;
}
else if (ssl->hsType == DYNAMIC_TYPE_ECC)
args->sigAlgo = ecc_dsa_sa_algo;
EncodeSigAlg(ssl->suites->hashAlgo, args->sigAlgo, args->verify);
/* Create the data to be signed. */
args->sigData = (byte*)XMALLOC(MAX_SIG_DATA_SZ, ssl->heap,
@ -3591,9 +3644,8 @@ int SendTls13CertificateVerify(WOLFSSL* ssl)
ERROR_OUT(MEMORY_E, exit_scv);
}
/* Digest the signature data and encode. Used in verify too. */
ret = CreateRSAEncodedSig(sig->buffer, args->sigData,
args->sigDataSz, ssl->suites->hashAlgo);
args->sigDataSz, args->sigAlgo, ssl->suites->hashAlgo);
if (ret < 0)
goto exit_scv;
sig->length = ret;
@ -3645,6 +3697,7 @@ int SendTls13CertificateVerify(WOLFSSL* ssl)
ret = RsaSign(ssl, sig->buffer, sig->length,
args->verify + HASH_SIG_SIZE + VERIFY_HEADER, &args->sigLen,
args->sigAlgo, ssl->suites->hashAlgo,
(RsaKey*)ssl->hsKey,
ssl->buffers.key->buffer, ssl->buffers.key->length,
#ifdef HAVE_PK_CALLBACKS
@ -3690,7 +3743,8 @@ int SendTls13CertificateVerify(WOLFSSL* ssl)
/* check for signature faults */
ret = VerifyRsaSign(ssl, args->verifySig, args->sigLen,
sig->buffer, sig->length, (RsaKey*)ssl->hsKey);
sig->buffer, sig->length, args->sigAlgo,
ssl->suites->hashAlgo, (RsaKey*)ssl->hsKey);
}
#endif /* !NO_RSA */

View File

@ -1587,6 +1587,15 @@ int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, word32 outLen,
WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng);
}
#ifdef WC_RSA_PSS
int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, word32 outLen,
enum wc_HashType hash, int mgf, RsaKey* key, WC_RNG* rng)
{
return RsaPublicEncryptEx(in, inLen, out, outLen, key,
RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD,
hash, mgf, NULL, 0, rng);
}
#endif
int wc_RsaEncryptSize(RsaKey* key)
{

View File

@ -3099,6 +3099,7 @@ struct WOLFSSL {
#endif
#ifdef WOLFSSL_TLS13
word16 namedGroup;
byte pssAlgo;
#endif
#ifdef HAVE_NTRU
word16 peerNtruKeyLen;
@ -3411,12 +3412,17 @@ WOLFSSL_LOCAL void ShrinkOutputBuffer(WOLFSSL* ssl);
WOLFSSL_LOCAL int VerifyClientSuite(WOLFSSL* ssl);
#ifndef NO_CERTS
#ifndef NO_RSA
WOLFSSL_LOCAL int CheckRsaPssPadding(const byte* plain, word32 plainSz,
byte* out, word32 sigSz,
enum wc_HashType hashType);
WOLFSSL_LOCAL int VerifyRsaSign(WOLFSSL* ssl,
byte* verifySig, word32 sigSz,
const byte* plain, word32 plainSz,
int sigAlgo, int hashAlgo,
RsaKey* key);
WOLFSSL_LOCAL int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out,
word32* outSz, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx);
WOLFSSL_LOCAL int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz,
byte* out, word32* outSz, int sigAlgo, int hashAlgo, RsaKey* key,
const byte* keyBuf, word32 keySz, void* ctx);
WOLFSSL_LOCAL int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz,
byte** out, int sigAlgo, int hashAlgo, RsaKey* key,
const byte* keyBuf, word32 keySz, void* ctx);

View File

@ -122,6 +122,9 @@ WOLFSSL_API int wc_RsaPrivateDecrypt(const byte* in, word32 inLen, byte* out,
word32 outLen, RsaKey* key);
WOLFSSL_API int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out,
word32 outLen, RsaKey* key, WC_RNG* rng);
WOLFSSL_API int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out,
word32 outLen, enum wc_HashType hash, int mgf,
RsaKey* key, WC_RNG* rng);
WOLFSSL_API int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out,
RsaKey* key);
WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out,