From e8fcf35098deb5f9a1344933451c1eb3b2c00d60 Mon Sep 17 00:00:00 2001 From: toddouska Date: Mon, 26 Aug 2013 17:14:19 -0700 Subject: [PATCH] add Rsa Public/Private client key exchange callbacks, examples --- cyassl/internal.h | 4 +++ cyassl/ssl.h | 20 ++++++++++++++ cyassl/test.h | 53 ++++++++++++++++++++++++++++++++++++ src/internal.c | 69 ++++++++++++++++++++++++++++++++++++++++------- src/ssl.c | 45 +++++++++++++++++++++++++++++++ 5 files changed, 182 insertions(+), 9 deletions(-) diff --git a/cyassl/internal.h b/cyassl/internal.h index 1fbb5cb07..8c76de40f 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1272,6 +1272,8 @@ struct CYASSL_CTX { #ifndef NO_RSA CallbackRsaSign RsaSignCb; /* User RsaSign Callback handler */ CallbackRsaVerify RsaVerifyCb; /* User RsaVerify Callback handler */ + CallbackRsaEnc RsaEncCb; /* User Rsa Public Encrypt handler */ + CallbackRsaDec RsaDecCb; /* User Rsa Private Decrypt handler */ #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ }; @@ -1857,6 +1859,8 @@ struct CYASSL { #ifndef NO_RSA void* RsaSignCtx; /* Rsa Sign Callback Context */ void* RsaVerifyCtx; /* Rsa Verify Callback Context */ + void* RsaEncCtx; /* Rsa Public Encrypt Callback Context */ + void* RsaDecCtx; /* Rsa Private Decrypt Callback Context */ #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ }; diff --git a/cyassl/ssl.h b/cyassl/ssl.h index e4e9c77ed..340f3fa68 100644 --- a/cyassl/ssl.h +++ b/cyassl/ssl.h @@ -1042,6 +1042,26 @@ CYASSL_API void CyaSSL_CTX_SetRsaVerifyCb(CYASSL_CTX*, CallbackRsaVerify); CYASSL_API void CyaSSL_SetRsaVerifyCtx(CYASSL* ssl, void *ctx); CYASSL_API void* CyaSSL_GetRsaVerifyCtx(CYASSL* ssl); +/* RSA Public Encrypt cb */ +typedef int (*CallbackRsaEnc)(CYASSL* ssl, + const unsigned char* in, unsigned int inSz, + unsigned char* out, unsigned int* outSz, + const unsigned char* keyDer, unsigned int keySz, + void* ctx); +CYASSL_API void CyaSSL_CTX_SetRsaEncCb(CYASSL_CTX*, CallbackRsaEnc); +CYASSL_API void CyaSSL_SetRsaEncCtx(CYASSL* ssl, void *ctx); +CYASSL_API void* CyaSSL_GetRsaEncCtx(CYASSL* ssl); + +/* RSA Private Decrypt cb */ +typedef int (*CallbackRsaDec)(CYASSL* ssl, + unsigned char* in, unsigned int inSz, + unsigned char** out, + const unsigned char* keyDer, unsigned int keySz, + void* ctx); +CYASSL_API void CyaSSL_CTX_SetRsaDecCb(CYASSL_CTX*, CallbackRsaDec); +CYASSL_API void CyaSSL_SetRsaDecCtx(CYASSL* ssl, void *ctx); +CYASSL_API void* CyaSSL_GetRsaDecCtx(CYASSL* ssl); + #ifndef NO_CERTS CYASSL_API void CyaSSL_CTX_SetCACb(CYASSL_CTX*, CallbackCACache); diff --git a/cyassl/test.h b/cyassl/test.h index 13b5cf901..50e51726f 100644 --- a/cyassl/test.h +++ b/cyassl/test.h @@ -1600,6 +1600,57 @@ static INLINE int myRsaVerify(CYASSL* ssl, byte* sig, word32 sigSz, return ret; } + +static INLINE int myRsaEnc(CYASSL* ssl, const byte* in, word32 inSz, + byte* out, word32* outSz, const byte* key, + word32 keySz, void* ctx) +{ + int ret; + word32 idx = 0; + RsaKey myKey; + RNG rng; + + (void)ssl; + (void)ctx; + + InitRng(&rng); + InitRsaKey(&myKey, NULL); + + ret = RsaPublicKeyDecode(key, &idx, &myKey, keySz); + if (ret == 0) { + ret = RsaPublicEncrypt(in, inSz, out, *outSz, &myKey, &rng); + if (ret > 0) { + *outSz = ret; + ret = 0; /* reset to success */ + } + } + FreeRsaKey(&myKey); + + return ret; +} + +static INLINE int myRsaDec(CYASSL* ssl, byte* in, word32 inSz, + byte** out, + const byte* key, word32 keySz, void* ctx) +{ + int ret; + word32 idx = 0; + RsaKey myKey; + + (void)ssl; + (void)ctx; + + InitRsaKey(&myKey, NULL); + + ret = RsaPrivateKeyDecode(key, &idx, &myKey, keySz); + if (ret == 0) { + ret = RsaPrivateDecryptInline(in, inSz, out, &myKey); + } + FreeRsaKey(&myKey); + + return ret; +} + #endif /* NO_RSA */ static INLINE void SetupPkCallbacks(CYASSL_CTX* ctx, CYASSL* ssl) @@ -1614,6 +1665,8 @@ static INLINE void SetupPkCallbacks(CYASSL_CTX* ctx, CYASSL* ssl) #ifndef NO_RSA CyaSSL_CTX_SetRsaSignCb(ctx, myRsaSign); CyaSSL_CTX_SetRsaVerifyCb(ctx, myRsaVerify); + CyaSSL_CTX_SetRsaEncCb(ctx, myRsaEnc); + CyaSSL_CTX_SetRsaDecCb(ctx, myRsaDec); #endif /* NO_RSA */ } diff --git a/src/internal.c b/src/internal.c index d663739d5..60b8eca50 100644 --- a/src/internal.c +++ b/src/internal.c @@ -440,6 +440,8 @@ int InitSSL_Ctx(CYASSL_CTX* ctx, CYASSL_METHOD* method) #ifndef NO_RSA ctx->RsaSignCb = NULL; ctx->RsaVerifyCb = NULL; + ctx->RsaEncCb = NULL; + ctx->RsaDecCb = NULL; #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ @@ -1512,6 +1514,8 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) #ifndef NO_RSA ssl->RsaSignCtx = NULL; ssl->RsaVerifyCtx = NULL; + ssl->RsaEncCtx = NULL; + ssl->RsaDecCtx = NULL; #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */ @@ -7641,6 +7645,14 @@ static void PickHashSigAlgo(CYASSL* ssl, word32 encSz = 0; word32 idx = 0; int ret = 0; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + if (ssl->ctx->RsaEncCb) + doUserRsa = 1; + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ switch (ssl->specs.kea) { #ifndef NO_RSA @@ -7654,12 +7666,28 @@ static void PickHashSigAlgo(CYASSL* ssl, if (ssl->peerRsaKeyPresent == 0) return NO_PEER_KEY; - ret = RsaPublicEncrypt(ssl->arrays->preMasterSecret, SECRET_LEN, - encSecret, sizeof(encSecret), ssl->peerRsaKey, - ssl->rng); - if (ret > 0) { - encSz = ret; - ret = 0; /* set success to 0 */ + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + encSz = sizeof(encSecret); + ret = ssl->ctx->RsaEncCb(ssl, + ssl->arrays->preMasterSecret, + SECRET_LEN, + encSecret, &encSz, + ssl->buffers.peerRsaKey.buffer, + ssl->buffers.peerRsaKey.length, + ssl->RsaEncCtx); + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaPublicEncrypt(ssl->arrays->preMasterSecret, + SECRET_LEN, encSecret, sizeof(encSecret), + ssl->peerRsaKey, ssl->rng); + if (ret > 0) { + encSz = ret; + ret = 0; /* set success to 0 */ + } } break; #endif @@ -10133,6 +10161,14 @@ static void PickHashSigAlgo(CYASSL* ssl, word32 idx = 0; RsaKey key; byte* tmp = 0; + byte doUserRsa = 0; + + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + if (ssl->ctx->RsaDecCb) + doUserRsa = 1; + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ InitRsaKey(&key, ssl->heap); @@ -10165,8 +10201,22 @@ static void PickHashSigAlgo(CYASSL* ssl, return INCOMPLETE_DATA; } - if (RsaPrivateDecryptInline(tmp, length, &out, &key) == - SECRET_LEN) { + if (doUserRsa) { + #ifdef HAVE_PK_CALLBACKS + #ifndef NO_RSA + ret = ssl->ctx->RsaDecCb(ssl, + tmp, length, &out, + ssl->buffers.key.buffer, + ssl->buffers.key.length, + ssl->RsaDecCtx); + #endif /* NO_RSA */ + #endif /*HAVE_PK_CALLBACKS */ + } + else { + ret = RsaPrivateDecryptInline(tmp, length, &out, &key); + } + + if (ret == SECRET_LEN) { XMEMCPY(ssl->arrays->preMasterSecret, out, SECRET_LEN); if (ssl->arrays->preMasterSecret[0] != ssl->chVersion.major @@ -10176,8 +10226,9 @@ static void PickHashSigAlgo(CYASSL* ssl, else ret = MakeMasterSecret(ssl); } - else + else { ret = RSA_PRIVATE_ERROR; + } } FreeRsaKey(&key); diff --git a/src/ssl.c b/src/ssl.c index f939b2c9d..53f970c9b 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -10484,6 +10484,51 @@ void* CyaSSL_GetRsaVerifyCtx(CYASSL* ssl) return NULL; } +void CyaSSL_CTX_SetRsaEncCb(CYASSL_CTX* ctx, CallbackRsaEnc cb) +{ + if (ctx) + ctx->RsaEncCb = cb; +} + + +void CyaSSL_SetRsaEncCtx(CYASSL* ssl, void *ctx) +{ + if (ssl) + ssl->RsaEncCtx = ctx; +} + + +void* CyaSSL_GetRsaEncCtx(CYASSL* ssl) +{ + if (ssl) + return ssl->RsaEncCtx; + + return NULL; +} + +void CyaSSL_CTX_SetRsaDecCb(CYASSL_CTX* ctx, CallbackRsaDec cb) +{ + if (ctx) + ctx->RsaDecCb = cb; +} + + +void CyaSSL_SetRsaDecCtx(CYASSL* ssl, void *ctx) +{ + if (ssl) + ssl->RsaDecCtx = ctx; +} + + +void* CyaSSL_GetRsaDecCtx(CYASSL* ssl) +{ + if (ssl) + return ssl->RsaDecCtx; + + return NULL; +} + + #endif /* NO_RSA */ #endif /* HAVE_PK_CALLBACKS */