From f3f80bd66e425082cc8a116edd98a3be2f462713 Mon Sep 17 00:00:00 2001 From: toddouska Date: Mon, 26 Aug 2013 16:27:29 -0700 Subject: [PATCH] add Rsa Sign/Verify callbacks, client/server examples --- cyassl/internal.h | 11 +++ cyassl/ssl.h | 18 +++++ cyassl/test.h | 60 +++++++++++++++ examples/server/server.c | 17 ++++- src/internal.c | 156 ++++++++++++++++++++++++++++++++++++--- src/ssl.c | 49 ++++++++++++ 6 files changed, 298 insertions(+), 13 deletions(-) diff --git a/cyassl/internal.h b/cyassl/internal.h index 62516a1d6..1fbb5cb07 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1269,6 +1269,10 @@ struct CYASSL_CTX { CallbackEccSign EccSignCb; /* User EccSign Callback handler */ CallbackEccVerify EccVerifyCb; /* User EccVerify Callback handler */ #endif /* HAVE_ECC */ + #ifndef NO_RSA + CallbackRsaSign RsaSignCb; /* User RsaSign Callback handler */ + CallbackRsaVerify RsaVerifyCb; /* User RsaVerify Callback handler */ + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ }; @@ -1566,6 +1570,9 @@ typedef struct Buffers { #ifdef HAVE_ECC buffer peerEccDsaKey; /* we own for Ecc Verify Callbacks */ #endif /* HAVE_ECC */ + #ifndef NO_RSA + buffer peerRsaKey; /* we own for Rsa Verify Callbacks */ + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ } Buffers; @@ -1847,6 +1854,10 @@ struct CYASSL { void* EccSignCtx; /* Ecc Sign Callback Context */ void* EccVerifyCtx; /* Ecc Verify Callback Context */ #endif /* HAVE_ECC */ + #ifndef NO_RSA + void* RsaSignCtx; /* Rsa Sign Callback Context */ + void* RsaVerifyCtx; /* Rsa Verify Callback Context */ + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ }; diff --git a/cyassl/ssl.h b/cyassl/ssl.h index 943739a6f..e4e9c77ed 100644 --- a/cyassl/ssl.h +++ b/cyassl/ssl.h @@ -1024,6 +1024,24 @@ CYASSL_API void CyaSSL_CTX_SetEccVerifyCb(CYASSL_CTX*, CallbackEccVerify); CYASSL_API void CyaSSL_SetEccVerifyCtx(CYASSL* ssl, void *ctx); CYASSL_API void* CyaSSL_GetEccVerifyCtx(CYASSL* ssl); +typedef int (*CallbackRsaSign)(CYASSL* ssl, + const unsigned char* in, unsigned int inSz, + unsigned char* out, unsigned int* outSz, + const unsigned char* keyDer, unsigned int keySz, + void* ctx); +CYASSL_API void CyaSSL_CTX_SetRsaSignCb(CYASSL_CTX*, CallbackRsaSign); +CYASSL_API void CyaSSL_SetRsaSignCtx(CYASSL* ssl, void *ctx); +CYASSL_API void* CyaSSL_GetRsaSignCtx(CYASSL* ssl); + +typedef int (*CallbackRsaVerify)(CYASSL* ssl, + unsigned char* sig, unsigned int sigSz, + unsigned char** out, + const unsigned char* keyDer, unsigned int keySz, + void* ctx); +CYASSL_API void CyaSSL_CTX_SetRsaVerifyCb(CYASSL_CTX*, CallbackRsaVerify); +CYASSL_API void CyaSSL_SetRsaVerifyCtx(CYASSL* ssl, void *ctx); +CYASSL_API void* CyaSSL_GetRsaVerifyCtx(CYASSL* ssl); + #ifndef NO_CERTS CYASSL_API void CyaSSL_CTX_SetCACb(CYASSL_CTX*, CallbackCACache); diff --git a/cyassl/test.h b/cyassl/test.h index bfd667951..13b5cf901 100644 --- a/cyassl/test.h +++ b/cyassl/test.h @@ -1502,6 +1502,8 @@ static INLINE void FreeAtomicUser(CYASSL* ssl) #ifdef HAVE_PK_CALLBACKS +#ifdef HAVE_ECC + static INLINE int myEccSign(CYASSL* ssl, const byte* in, word32 inSz, byte* out, word32* outSz, const byte* key, word32 keySz, void* ctx) { @@ -1545,6 +1547,60 @@ static INLINE int myEccVerify(CYASSL* ssl, const byte* sig, word32 sigSz, return ret; } +#endif /* HAVE_ECC */ + +#ifndef NO_RSA + +static INLINE int myRsaSign(CYASSL* ssl, const byte* in, word32 inSz, + byte* out, word32* outSz, const byte* key, word32 keySz, void* ctx) +{ + RNG rng; + int ret; + word32 idx = 0; + RsaKey myKey; + + (void)ssl; + (void)ctx; + + InitRng(&rng); + InitRsaKey(&myKey, NULL); + + ret = RsaPrivateKeyDecode(key, &idx, &myKey, keySz); + if (ret == 0) + ret = RsaSSL_Sign(in, inSz, out, *outSz, &myKey, &rng); + if (ret > 0) { /* save and convert to 0 success */ + *outSz = ret; + ret = 0; + } + FreeRsaKey(&myKey); + + return ret; +} + + +static INLINE int myRsaVerify(CYASSL* ssl, byte* sig, word32 sigSz, + byte** out, + const byte* key, word32 keySz, + void* ctx) +{ + int ret; + word32 idx = 0; + RsaKey myKey; + + (void)ssl; + (void)ctx; + + InitRsaKey(&myKey, NULL); + + ret = RsaPublicKeyDecode(key, &idx, &myKey, keySz); + if (ret == 0) + ret = RsaSSL_VerifyInline(sig, sigSz, out, &myKey); + FreeRsaKey(&myKey); + + return ret; +} + +#endif /* NO_RSA */ static INLINE void SetupPkCallbacks(CYASSL_CTX* ctx, CYASSL* ssl) { @@ -1555,6 +1611,10 @@ static INLINE void SetupPkCallbacks(CYASSL_CTX* ctx, CYASSL* ssl) CyaSSL_CTX_SetEccSignCb(ctx, myEccSign); CyaSSL_CTX_SetEccVerifyCb(ctx, myEccVerify); #endif /* HAVE_ECC */ + #ifndef NO_RSA + CyaSSL_CTX_SetRsaSignCb(ctx, myRsaSign); + CyaSSL_CTX_SetRsaVerifyCb(ctx, myRsaVerify); + #endif /* NO_RSA */ } #endif /* HAVE_PK_CALLBACKS */ diff --git a/examples/server/server.c b/examples/server/server.c index b8d06a9b0..9be9c4802 100644 --- a/examples/server/server.c +++ b/examples/server/server.c @@ -127,6 +127,9 @@ static void Usage(void) printf("-o Perform OCSP lookup on peer certificate\n"); printf("-O Perform OCSP lookup using as responder\n"); #endif +#ifdef HAVE_PK_CALLBACKS + printf("-P Public Key Callbacks\n"); +#endif } #ifdef CYASSL_MDK_SHELL @@ -157,6 +160,7 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) int nonBlocking = 0; int trackMemory = 0; int fewerPackets = 0; + int pkCallbacks = 0; char* cipherList = NULL; char* verifyCert = (char*)cliCert; char* ourCert = (char*)svrCert; @@ -181,8 +185,9 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) ourKey = (char*)eccKey; #endif (void)trackMemory; + (void)pkCallbacks; - while ((ch = mygetopt(argc, argv, "?dbstnNufp:v:l:A:c:k:S:oO:")) != -1) { + while ((ch = mygetopt(argc, argv, "?dbstnNufPp:v:l:A:c:k:S:oO:")) != -1) { switch (ch) { case '?' : Usage(); @@ -218,6 +223,12 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) fewerPackets = 1; break; + case 'P' : + #ifdef HAVE_PK_CALLBACKS + pkCallbacks = 1; + #endif + break; + case 'p' : port = atoi(myoptarg); #if !defined(NO_MAIN_DRIVER) || defined(USE_WINDOWS_API) @@ -454,6 +465,10 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) CyaSSL_CTX_OCSP_set_override_url(ctx, ocspUrl); } #endif +#ifdef HAVE_PK_CALLBACKS + if (pkCallbacks) + SetupPkCallbacks(ctx, ssl); +#endif tcp_accept(&sockfd, &clientfd, (func_args*)args, port, useAnyAddr, doDTLS); if (!doDTLS) diff --git a/src/internal.c b/src/internal.c index 9e512b603..d663739d5 100644 --- a/src/internal.c +++ b/src/internal.c @@ -437,6 +437,10 @@ int InitSSL_Ctx(CYASSL_CTX* ctx, CYASSL_METHOD* method) ctx->EccSignCb = NULL; ctx->EccVerifyCb = NULL; #endif /* HAVE_ECC */ + #ifndef NO_RSA + ctx->RsaSignCb = NULL; + ctx->RsaVerifyCb = NULL; + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ if (InitMutex(&ctx->countMutex) < 0) { @@ -1282,6 +1286,10 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) ssl->buffers.peerEccDsaKey.buffer = 0; ssl->buffers.peerEccDsaKey.length = 0; #endif /* HAVE_ECC */ + #ifndef NO_RSA + ssl->buffers.peerRsaKey.buffer = 0; + ssl->buffers.peerRsaKey.length = 0; + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ #ifdef KEEP_PEER_CERT @@ -1501,6 +1509,10 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) ssl->EccSignCtx = NULL; ssl->EccVerifyCtx = NULL; #endif /* HAVE_ECC */ + #ifndef NO_RSA + ssl->RsaSignCtx = NULL; + ssl->RsaVerifyCtx = NULL; + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ /* all done with init, now can return errors, call other stuff */ @@ -1715,6 +1727,9 @@ void SSL_ResourceFree(CYASSL* ssl) #ifdef HAVE_ECC XFREE(ssl->buffers.peerEccDsaKey.buffer, ssl->heap, DYNAMIC_TYPE_ECC); #endif /* HAVE_ECC */ + #ifndef NO_RSA + XFREE(ssl->buffers.peerRsaKey.buffer, ssl->heap, DYNAMIC_TYPE_RSA); + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ #ifdef HAVE_TLS_EXTENSIONS TLSX_FreeAll(ssl->extensions); @@ -1808,6 +1823,10 @@ void FreeHandshakeResources(CYASSL* ssl) XFREE(ssl->buffers.peerEccDsaKey.buffer, ssl->heap, DYNAMIC_TYPE_ECC); ssl->buffers.peerEccDsaKey.buffer = NULL; #endif /* HAVE_ECC */ + #ifndef NO_RSA + XFREE(ssl->buffers.peerRsaKey.buffer, ssl->heap, DYNAMIC_TYPE_RSA); + ssl->buffers.peerRsaKey.buffer = NULL; + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ } @@ -3238,8 +3257,24 @@ static int DoCertificate(CYASSL* ssl, byte* input, word32* inOutIdx) ssl->peerRsaKey, dCert.pubKeySize) != 0) { ret = PEER_KEY_ERROR; } - else + else { ssl->peerRsaKeyPresent = 1; + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + ssl->buffers.peerRsaKey.buffer = + XMALLOC(dCert.pubKeySize, + ssl->heap, DYNAMIC_TYPE_RSA); + if (ssl->buffers.peerRsaKey.buffer == NULL) + ret = MEMORY_ERROR; + else { + XMEMCPY(ssl->buffers.peerRsaKey.buffer, + dCert.publicKey, dCert.pubKeySize); + ssl->buffers.peerRsaKey.length = + dCert.pubKeySize; + } + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ + } } break; #endif /* NO_RSA */ @@ -7453,11 +7488,29 @@ static void PickHashSigAlgo(CYASSL* ssl, { int ret; byte* out; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaVerifyCb) + doUserRsa = 1; + #endif /*HAVE_PK_CALLBACKS */ if (!ssl->peerRsaKeyPresent) return NO_PEER_KEY; - ret = RsaSSL_VerifyInline(signature, sigLen,&out, ssl->peerRsaKey); + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + ret = ssl->ctx->RsaVerifyCb(ssl, signature, sigLen, + &out, + ssl->buffers.peerRsaKey.buffer, + ssl->buffers.peerRsaKey.length, + ssl->RsaVerifyCtx); + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaSSL_VerifyInline(signature, sigLen,&out, + ssl->peerRsaKey); + } if (IsAtLeastTLSv1_2(ssl)) { byte encodedSig[MAX_ENCODED_SIG_SZ]; @@ -7982,6 +8035,13 @@ static void PickHashSigAlgo(CYASSL* ssl, } #ifndef NO_RSA else { + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaSignCb) + doUserRsa = 1; + #endif /*HAVE_PK_CALLBACKS */ + if (IsAtLeastTLSv1_2(ssl)) { #ifndef NO_OLD_TLS byte* digest = ssl->certHashes.sha; @@ -8020,8 +8080,23 @@ static void PickHashSigAlgo(CYASSL* ssl, } c16toa((word16)length, verify + extraSz); /* prepend hdr */ - ret = RsaSSL_Sign(signBuffer, signSz, verify + extraSz + + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + word32 ioLen = ENCRYPT_LEN; + ret = ssl->ctx->RsaSignCb(ssl, signBuffer, signSz, + verify + extraSz + VERIFY_HEADER, + &ioLen, + ssl->buffers.key.buffer, + ssl->buffers.key.length, + ssl->RsaSignCtx); + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaSSL_Sign(signBuffer, signSz, verify + extraSz + VERIFY_HEADER, ENCRYPT_LEN, &key, ssl->rng); + } if (ret > 0) ret = 0; /* RSA reset */ @@ -8436,6 +8511,13 @@ static void PickHashSigAlgo(CYASSL* ssl, byte* signBuffer = hash; word32 signSz = sizeof(hash); byte encodedSig[MAX_ENCODED_SIG_SZ]; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaSignCb) + doUserRsa = 1; + #endif /*HAVE_PK_CALLBACKS */ + if (IsAtLeastTLSv1_2(ssl)) { byte* digest = &hash[MD5_DIGEST_SIZE]; int typeH = SHAh; @@ -8464,13 +8546,26 @@ static void PickHashSigAlgo(CYASSL* ssl, c16toa((word16)sigSz, output + idx); idx += LENGTH_SZ; - ret = RsaSSL_Sign(signBuffer, signSz, output + idx, sigSz, - &rsaKey, ssl->rng); + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + word32 ioLen = sigSz; + ret = ssl->ctx->RsaSignCb(ssl, signBuffer, signSz, + output + idx, + &ioLen, + ssl->buffers.key.buffer, + ssl->buffers.key.length, + ssl->RsaSignCtx); + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaSSL_Sign(signBuffer, signSz, output + idx, + sigSz, &rsaKey, ssl->rng); + if (ret > 0) + ret = 0; /* reset on success */ + } FreeRsaKey(&rsaKey); ecc_free(&dsaKey); - if (ret > 0) - ret = 0; /* reset on success */ - else + if (ret < 0) return ret; } else #endif @@ -8740,6 +8835,13 @@ static void PickHashSigAlgo(CYASSL* ssl, byte* signBuffer = hash; word32 signSz = sizeof(hash); byte encodedSig[MAX_ENCODED_SIG_SZ]; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaSignCb) + doUserRsa = 1; + #endif /*HAVE_PK_CALLBACKS */ + if (IsAtLeastTLSv1_2(ssl)) { byte* digest = &hash[MD5_DIGEST_SIZE]; int typeH = SHAh; @@ -8764,10 +8866,23 @@ static void PickHashSigAlgo(CYASSL* ssl, typeH); signBuffer = encodedSig; } - ret = RsaSSL_Sign(signBuffer, signSz, output + idx, sigSz, - &rsaKey, ssl->rng); + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + word32 ioLen = sigSz; + ret = ssl->ctx->RsaSignCb(ssl, signBuffer, signSz, + output + idx, + &ioLen, + ssl->buffers.key.buffer, + ssl->buffers.key.length, + ssl->RsaSignCtx); + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaSSL_Sign(signBuffer, signSz, output + idx, + sigSz, &rsaKey, ssl->rng); + } FreeRsaKey(&rsaKey); - if (ret <= 0) + if (ret < 0) return ret; } #endif @@ -9768,10 +9883,27 @@ static void PickHashSigAlgo(CYASSL* ssl, if (ssl->peerRsaKeyPresent != 0) { byte* out; int outLen; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaVerifyCb) + doUserRsa = 1; + #endif /*HAVE_PK_CALLBACKS */ CYASSL_MSG("Doing RSA peer cert verify"); - outLen = RsaSSL_VerifyInline(sig, sz, &out, ssl->peerRsaKey); + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + outLen = ssl->ctx->RsaVerifyCb(ssl, sig, sz, + &out, + ssl->buffers.peerRsaKey.buffer, + ssl->buffers.peerRsaKey.length, + ssl->RsaVerifyCtx); + #endif /*HAVE_PK_CALLBACKS */ + } + else { + outLen = RsaSSL_VerifyInline(sig, sz, &out, ssl->peerRsaKey); + } if (IsAtLeastTLSv1_2(ssl)) { byte encodedSig[MAX_ENCODED_SIG_SZ]; diff --git a/src/ssl.c b/src/ssl.c index 43d49450a..f939b2c9d 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -10437,6 +10437,55 @@ void* CyaSSL_GetEccVerifyCtx(CYASSL* ssl) #endif /* HAVE_ECC */ +#ifndef NO_RSA + +void CyaSSL_CTX_SetRsaSignCb(CYASSL_CTX* ctx, CallbackRsaSign cb) +{ + if (ctx) + ctx->RsaSignCb = cb; +} + + +void CyaSSL_SetRsaSignCtx(CYASSL* ssl, void *ctx) +{ + if (ssl) + ssl->RsaSignCtx = ctx; +} + + +void* CyaSSL_GetRsaSignCtx(CYASSL* ssl) +{ + if (ssl) + return ssl->RsaSignCtx; + + return NULL; +} + + +void CyaSSL_CTX_SetRsaVerifyCb(CYASSL_CTX* ctx, CallbackRsaVerify cb) +{ + if (ctx) + ctx->RsaVerifyCb = cb; +} + + +void CyaSSL_SetRsaVerifyCtx(CYASSL* ssl, void *ctx) +{ + if (ssl) + ssl->RsaVerifyCtx = ctx; +} + + +void* CyaSSL_GetRsaVerifyCtx(CYASSL* ssl) +{ + if (ssl) + return ssl->RsaVerifyCtx; + + return NULL; +} + +#endif /* NO_RSA */ + #endif /* HAVE_PK_CALLBACKS */ #endif /* NO_CERTS */