diff --git a/src/internal.c b/src/internal.c index 7ccc871a4..71b851c27 100644 --- a/src/internal.c +++ b/src/internal.c @@ -3828,7 +3828,20 @@ int DhAgree(WOLFSSL* ssl, DhKey* dhKey, return ret; #endif - ret = wc_DhAgree(dhKey, agree, agreeSz, priv, privSz, otherPub, otherPubSz); +#ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->DhAgreeCb) { + void* ctx = wolfSSL_GetDhAgreeCtx(ssl); + + WOLFSSL_MSG("Calling DhAgree Callback Function"); + ret = ssl->ctx->DhAgreeCb(ssl, dhKey, priv, privSz, + otherPub, otherPubSz, agree, agreeSz, ctx); + } + else +#endif + { + ret = wc_DhAgree(dhKey, agree, agreeSz, priv, privSz, otherPub, + otherPubSz); + } /* Handle async pending response */ #ifdef WOLFSSL_ASYNC_CRYPT diff --git a/src/ssl.c b/src/ssl.c index edd0a3a12..a4a8a8791 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -23864,6 +23864,29 @@ void* wolfSSL_GetEccSharedSecretCtx(WOLFSSL* ssl) } #endif /* HAVE_ECC */ +#ifndef NO_DH + +void wolfSSL_CTX_SetDhAgreeCb(WOLFSSL_CTX* ctx, CallbackDhAgree cb) +{ + if (ctx) + ctx->DhAgreeCb = cb; +} + +void wolfSSL_SetDhAgreeCtx(WOLFSSL* ssl, void *ctx) +{ + if (ssl) + ssl->DhAgreeCtx = ctx; +} + +void* wolfSSL_GetDhAgreeCtx(WOLFSSL* ssl) +{ + if (ssl) + return ssl->DhAgreeCtx; + + return NULL; +} +#endif /* !NO_DH */ + #ifdef HAVE_ED25519 void wolfSSL_CTX_SetEd25519SignCb(WOLFSSL_CTX* ctx, CallbackEd25519Sign cb) { diff --git a/tests/api.c b/tests/api.c index bd6777154..e91fde627 100644 --- a/tests/api.c +++ b/tests/api.c @@ -988,6 +988,7 @@ static THREAD_RETURN WOLFSSL_THREAD test_server_nofail(void* args) SOCKET_T clientfd = 0; word16 port; + callback_functions* cbf = NULL; WOLFSSL_METHOD* method = 0; WOLFSSL_CTX* ctx = 0; WOLFSSL* ssl = 0; @@ -1002,9 +1003,9 @@ static THREAD_RETURN WOLFSSL_THREAD test_server_nofail(void* args) #endif ((func_args*)args)->return_code = TEST_FAIL; - if (((func_args*)args)->callbacks != NULL && - ((func_args*)args)->callbacks->method != NULL) { - method = ((func_args*)args)->callbacks->method(); + cbf = ((func_args*)args)->callbacks; + if (cbf != NULL && cbf->method != NULL) { + method = cbf->method(); } else { method = wolfSSLv23_server_method(); @@ -1049,6 +1050,11 @@ static THREAD_RETURN WOLFSSL_THREAD test_server_nofail(void* args) goto done; } + /* call ctx setup callback */ + if (cbf != NULL && cbf->ctx_ready != NULL) { + cbf->ctx_ready(ctx); + } + ssl = wolfSSL_new(ctx); tcp_accept(&sockfd, &clientfd, (func_args*)args, port, 0, 0, 0, 0, 1); CloseSocket(sockfd); @@ -1066,6 +1072,11 @@ static THREAD_RETURN WOLFSSL_THREAD test_server_nofail(void* args) #endif #endif + /* call ssl setup callback */ + if (cbf != NULL && cbf->ssl_ready != NULL) { + cbf->ssl_ready(ssl); + } + do { #ifdef WOLFSSL_ASYNC_CRYPT if (err == WC_PENDING_E) { @@ -1108,13 +1119,14 @@ static THREAD_RETURN WOLFSSL_THREAD test_server_nofail(void* args) Task_yield(); #endif + ((func_args*)args)->return_code = TEST_SUCCESS; + done: wolfSSL_shutdown(ssl); wolfSSL_free(ssl); wolfSSL_CTX_free(ctx); CloseSocket(clientfd); - ((func_args*)args)->return_code = TEST_SUCCESS; #ifdef WOLFSSL_TIRTOS fdCloseSession(Task_self()); @@ -1136,6 +1148,7 @@ typedef int (*cbType)(WOLFSSL_CTX *ctx, WOLFSSL *ssl); static void test_client_nofail(void* args, void *cb) { SOCKET_T sockfd = 0; + callback_functions* cbf = NULL; WOLFSSL_METHOD* method = 0; WOLFSSL_CTX* ctx = 0; @@ -1150,11 +1163,13 @@ static void test_client_nofail(void* args, void *cb) #ifdef WOLFSSL_TIRTOS fdOpenSession(Task_self()); #endif + if (((func_args*)args)->callbacks != NULL) { + cbf = ((func_args*)args)->callbacks; + } ((func_args*)args)->return_code = TEST_FAIL; - if (((func_args*)args)->callbacks != NULL && - ((func_args*)args)->callbacks->method != NULL) { - method = ((func_args*)args)->callbacks->method(); + if (cbf != NULL && cbf->method != NULL) { + method = cbf->method(); } else { method = wolfSSLv23_client_method(); @@ -1185,6 +1200,11 @@ static void test_client_nofail(void* args, void *cb) goto done2; } + /* call ctx setup callback */ + if (cbf != NULL && cbf->ctx_ready != NULL) { + cbf->ctx_ready(ctx); + } + ssl = wolfSSL_new(ctx); tcp_connect(&sockfd, wolfSSLIP, ((func_args*)args)->signal->port, 0, 0, ssl); @@ -1193,6 +1213,11 @@ static void test_client_nofail(void* args, void *cb) goto done2; } + /* call ssl setup callback */ + if (cbf != NULL && cbf->ssl_ready != NULL) { + cbf->ssl_ready(ssl); + } + do { #ifdef WOLFSSL_ASYNC_CRYPT if (err == WC_PENDING_E) { @@ -1230,12 +1255,13 @@ static void test_client_nofail(void* args, void *cb) printf("Server response: %s\n", reply); } + ((func_args*)args)->return_code = TEST_SUCCESS; + done2: wolfSSL_free(ssl); wolfSSL_CTX_free(ctx); CloseSocket(sockfd); - ((func_args*)args)->return_code = TEST_SUCCESS; #ifdef WOLFSSL_TIRTOS fdCloseSession(Task_self()); @@ -13626,6 +13652,8 @@ static void test_wolfSSL_ERR_peek_last_error_line(void) StartTCP(); InitTcpReady(&ready); + XMEMSET(&client_cb, 0, sizeof(callback_functions)); + XMEMSET(&server_cb, 0, sizeof(callback_functions)); client_cb.method = wolfTLSv1_1_client_method; server_cb.method = wolfTLSv1_2_server_method; @@ -14079,6 +14107,8 @@ static void test_wolfSSL_msgCb(void) StartTCP(); InitTcpReady(&ready); + XMEMSET(&client_cb, 0, sizeof(callback_functions)); + XMEMSET(&server_cb, 0, sizeof(callback_functions)); client_cb.method = wolfTLSv1_2_client_method; server_cb.method = wolfTLSv1_2_server_method; @@ -15036,6 +15066,182 @@ static int test_tls13_apis(void) #endif +#ifdef HAVE_PK_CALLBACKS +#if !defined(NO_FILESYSTEM) && !defined(NO_CERTS) && !defined(NO_RSA) && \ + !defined(NO_WOLFSSL_CLIENT) && !defined(NO_DH) && \ + defined(HAVE_IO_TESTS_DEPENDENCIES) && !defined(SINGLE_THREADED) +static int my_DhCallback(WOLFSSL* ssl, struct DhKey* key, + const unsigned char* priv, unsigned int privSz, + const unsigned char* pubKeyDer, unsigned int pubKeySz, + unsigned char* out, unsigned int* outlen, + void* ctx) +{ + /* Test fail when context associated with WOLFSSL is NULL */ + if (ctx == NULL) { + return -1; + } + + (void)ssl; + /* return 0 on success */ + return wc_DhAgree(key, out, outlen, priv, privSz, pubKeyDer, pubKeySz); +}; + +static void test_dh_ctx_setup(WOLFSSL_CTX* ctx) { + wolfSSL_CTX_SetDhAgreeCb(ctx, my_DhCallback); + AssertIntEQ(wolfSSL_CTX_set_cipher_list(ctx, "DHE-RSA-AES128-SHA256"), + WOLFSSL_SUCCESS); +} + + +static void test_dh_ssl_setup(WOLFSSL* ssl) +{ + static int dh_test_ctx = 1; + int ret; + + wolfSSL_SetDhAgreeCtx(ssl, &dh_test_ctx); + AssertIntEQ(*((int*)wolfSSL_GetDhAgreeCtx(ssl)), dh_test_ctx); + ret = wolfSSL_SetTmpDH_file(ssl, dhParamFile, WOLFSSL_FILETYPE_PEM); + if (ret != WOLFSSL_SUCCESS && ret != SIDE_ERROR) { + AssertIntEQ(ret, WOLFSSL_SUCCESS); + } +} + +static void test_dh_ssl_setup_fail(WOLFSSL* ssl) +{ + int ret; + + wolfSSL_SetDhAgreeCtx(ssl, NULL); + AssertNull(wolfSSL_GetDhAgreeCtx(ssl)); + ret = wolfSSL_SetTmpDH_file(ssl, dhParamFile, WOLFSSL_FILETYPE_PEM); + if (ret != WOLFSSL_SUCCESS && ret != SIDE_ERROR) { + AssertIntEQ(ret, WOLFSSL_SUCCESS); + } +} +#endif + +static void test_DhCallbacks(void) +{ +#if !defined(NO_FILESYSTEM) && !defined(NO_CERTS) && !defined(NO_RSA) && \ + !defined(NO_WOLFSSL_CLIENT) && !defined(NO_DH) && \ + defined(HAVE_IO_TESTS_DEPENDENCIES) && !defined(SINGLE_THREADED) + WOLFSSL_CTX *ctx; + WOLFSSL *ssl; + tcp_ready ready; + func_args server_args; + func_args client_args; + THREAD_TYPE serverThread; + callback_functions func_cb; + int test; + + printf(testingFmt, "test_DhCallbacks"); + + AssertNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method())); + wolfSSL_CTX_SetDhAgreeCb(ctx, &my_DhCallback); + + /* load client ca cert */ + AssertIntEQ(wolfSSL_CTX_load_verify_locations(ctx, caCertFile, 0), + WOLFSSL_SUCCESS); + + /* test with NULL arguments */ + wolfSSL_SetDhAgreeCtx(NULL, &test); + AssertNull(wolfSSL_GetDhAgreeCtx(NULL)); + + /* test success case */ + test = 1; + AssertNotNull(ssl = wolfSSL_new(ctx)); + wolfSSL_SetDhAgreeCtx(ssl, &test); + AssertIntEQ(*((int*)wolfSSL_GetDhAgreeCtx(ssl)), test); + + wolfSSL_free(ssl); + wolfSSL_CTX_free(ctx); + + /* test a connection where callback is used */ +#ifdef WOLFSSL_TIRTOS + fdOpenSession(Task_self()); +#endif + XMEMSET(&server_args, 0, sizeof(func_args)); + XMEMSET(&client_args, 0, sizeof(func_args)); + XMEMSET(&func_cb, 0, sizeof(callback_functions)); + + StartTCP(); + InitTcpReady(&ready); + +#if defined(USE_WINDOWS_API) + /* use RNG to get random port if using windows */ + ready.port = GetRandomPort(); +#endif + + server_args.signal = &ready; + client_args.signal = &ready; + server_args.return_code = TEST_FAIL; + client_args.return_code = TEST_FAIL; + + /* set callbacks to use DH functions */ + func_cb.ctx_ready = &test_dh_ctx_setup; + func_cb.ssl_ready = &test_dh_ssl_setup; + client_args.callbacks = &func_cb; + server_args.callbacks = &func_cb; + + start_thread(test_server_nofail, &server_args, &serverThread); + wait_tcp_ready(&server_args); + test_client_nofail(&client_args, NULL); + join_thread(serverThread); + + AssertTrue(client_args.return_code); + AssertTrue(server_args.return_code); + + FreeTcpReady(&ready); + +#ifdef WOLFSSL_TIRTOS + fdOpenSession(Task_self()); +#endif + + /* now set user ctx to not be 1 so that the callback returns fail case */ +#ifdef WOLFSSL_TIRTOS + fdOpenSession(Task_self()); +#endif + XMEMSET(&server_args, 0, sizeof(func_args)); + XMEMSET(&client_args, 0, sizeof(func_args)); + XMEMSET(&func_cb, 0, sizeof(callback_functions)); + + StartTCP(); + InitTcpReady(&ready); + +#if defined(USE_WINDOWS_API) + /* use RNG to get random port if using windows */ + ready.port = GetRandomPort(); +#endif + + server_args.signal = &ready; + client_args.signal = &ready; + server_args.return_code = TEST_FAIL; + client_args.return_code = TEST_FAIL; + + /* set callbacks to use DH functions */ + func_cb.ctx_ready = &test_dh_ctx_setup; + func_cb.ssl_ready = &test_dh_ssl_setup_fail; + client_args.callbacks = &func_cb; + server_args.callbacks = &func_cb; + + start_thread(test_server_nofail, &server_args, &serverThread); + wait_tcp_ready(&server_args); + test_client_nofail(&client_args, NULL); + join_thread(serverThread); + + AssertIntEQ(client_args.return_code, TEST_FAIL); + AssertIntEQ(server_args.return_code, TEST_FAIL); + + FreeTcpReady(&ready); + +#ifdef WOLFSSL_TIRTOS + fdOpenSession(Task_self()); +#endif + + printf(resultFmt, passed); +#endif +} +#endif /* HAVE_PK_CALLBACKS */ + #ifdef HAVE_HASHDRBG static int test_wc_RNG_GenerateBlock() @@ -15162,6 +15368,11 @@ void ApiTest(void) AssertIntEQ(test_RsaSigFailure_cm(), ASN_SIG_CONFIRM_E); #endif /* NO_CERTS */ +#ifdef HAVE_PK_CALLBACKS + /* public key callback tests */ + test_DhCallbacks(); +#endif + /*wolfcrypt */ printf("\n-----------------wolfcrypt unit tests------------------\n"); AssertFalse(test_wolfCrypt_Init()); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 1d6014a55..01bbb0a4c 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -2418,6 +2418,9 @@ struct WOLFSSL_CTX { CallbackX25519SharedSecret X25519SharedSecretCb; #endif #endif /* HAVE_ECC */ + #ifndef NO_DH + CallbackDhAgree DhAgreeCb; /* User DH Agree Callback handler */ + #endif #ifndef NO_RSA CallbackRsaSign RsaSignCb; /* User RsaSign Callback handler */ CallbackRsaVerify RsaVerifyCb; /* User RsaVerify Callback handler */ @@ -3498,6 +3501,9 @@ struct WOLFSSL { void* X25519SharedSecretCtx; /* X25519 Pms Callback Context */ #endif #endif /* HAVE_ECC */ + #ifndef NO_DH + void* DhAgreeCtx; /* DH Pms Callback Context */ + #endif /* !NO_DH */ #ifndef NO_RSA void* RsaSignCtx; /* Rsa Sign Callback Context */ void* RsaVerifyCtx; /* Rsa Verify Callback Context */ diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index a7594402d..ebfc5fac2 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -1690,6 +1690,19 @@ WOLFSSL_API void wolfSSL_CTX_SetEccSharedSecretCb(WOLFSSL_CTX*, CallbackEccShar WOLFSSL_API void wolfSSL_SetEccSharedSecretCtx(WOLFSSL* ssl, void *ctx); WOLFSSL_API void* wolfSSL_GetEccSharedSecretCtx(WOLFSSL* ssl); +#ifndef NO_DH +/* Public DH Key Callback support */ +struct DhKey; +typedef int (*CallbackDhAgree)(WOLFSSL* ssl, struct DhKey* key, + const unsigned char* priv, unsigned int privSz, + const unsigned char* otherPubKeyDer, unsigned int otherPubKeySz, + unsigned char* out, unsigned int* outlen, + void* ctx); +WOLFSSL_API void wolfSSL_CTX_SetDhAgreeCb(WOLFSSL_CTX*, CallbackDhAgree); +WOLFSSL_API void wolfSSL_SetDhAgreeCtx(WOLFSSL* ssl, void *ctx); +WOLFSSL_API void* wolfSSL_GetDhAgreeCtx(WOLFSSL* ssl); +#endif /* !NO_DH */ + struct ed25519_key; typedef int (*CallbackEd25519Sign)(WOLFSSL* ssl, const unsigned char* in, unsigned int inSz, diff --git a/wolfssl/test.h b/wolfssl/test.h index 577fe4571..376d54532 100644 --- a/wolfssl/test.h +++ b/wolfssl/test.h @@ -2013,6 +2013,21 @@ static INLINE int myX25519SharedSecret(WOLFSSL* ssl, curve25519_key* otherKey, #endif /* HAVE_ECC */ +#ifndef NO_DH +static INLINE int myDhCallback(WOLFSSL* ssl, struct DhKey* key, + const unsigned char* priv, unsigned int privSz, + const unsigned char* pubKeyDer, unsigned int pubKeySz, + unsigned char* out, unsigned int* outlen, + void* ctx) +{ + (void)ctx; + (void)ssl; + /* return 0 on success */ + return wc_DhAgree(key, out, outlen, priv, privSz, pubKeyDer, pubKeySz); +}; + +#endif /* !NO_DH */ + #ifndef NO_RSA static INLINE int myRsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, @@ -2244,6 +2259,9 @@ static INLINE void SetupPkCallbacks(WOLFSSL_CTX* ctx, WOLFSSL* ssl) wolfSSL_CTX_SetEccVerifyCb(ctx, myEccVerify); wolfSSL_CTX_SetEccSharedSecretCb(ctx, myEccSharedSecret); #endif /* HAVE_ECC */ + #ifndef NO_DH + wolfSSL_CTX_SetDhAgreeCb(ctx, myDhCallback); + #endif #ifdef HAVE_ED25519 wolfSSL_CTX_SetEd25519SignCb(ctx, myEd25519Sign); wolfSSL_CTX_SetEd25519VerifyCb(ctx, myEd25519Verify);