diff --git a/examples/client/client.c b/examples/client/client.c index 6ee68ad04..a5dc47cd1 100644 --- a/examples/client/client.c +++ b/examples/client/client.c @@ -2923,6 +2923,13 @@ THREAD_RETURN WOLFSSL_THREAD client_test(void* args) err_sys("unable to get SSL object"); } +#ifndef NO_PSK + if (usePsk) { + #if defined(OPENSSL_EXTRA) && defined(WOLFSSL_TLS13) && defined(TEST_PSK_USE_SESSION) + SSL_set_psk_use_session_callback(ssl, my_psk_use_session_cb); + #endif + } +#endif #ifndef NO_CERTS if (useClientCert && loadCertKeyIntoSSLObj){ diff --git a/src/ssl.c b/src/ssl.c index 31ba5d432..56d74d9e0 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -14687,7 +14687,25 @@ int wolfSSL_set_compression(WOLFSSL* ssl) ssl->options.haveStaticECC, ssl->options.haveAnon, ssl->options.side); } - + #ifdef OPENSSL_EXTRA + /** + * set call back function for psk session use + * @param ssl a pointer to WOLFSSL structure + * @param cb a function pointer to wc_psk_use_session_cb + * @return none + */ + void wolfSSL_set_psk_use_session_callback(WOLFSSL* ssl, + wc_psk_use_session_cb_func cb) + { + WOLFSSL_ENTER("wolfSSL_set_psk_use_session_callback"); + + ssl->options.havePSK = 1; + ssl->options.session_psk_cb = cb; + + WOLFSSL_LEAVE("wolfSSL_set_psk_use_session_callback", WOLFSSL_SUCCESS); + } + #endif + void wolfSSL_CTX_set_psk_server_callback(WOLFSSL_CTX* ctx, wc_psk_server_callback cb) { @@ -22095,6 +22113,28 @@ void wolfSSL_SESSION_free(WOLFSSL_SESSION* session) FreeSession(session, 0); #endif } +/** +* set cipher to WOLFSSL_SESSION from WOLFSSL_CIPHER +* @param session a pointer to WOLFSSL_SESSION structure +* @param cipher a function pointer to WOLFSSL_CIPHER +* @return WOLFSSL_SUCCESS on success, otherwise WOLFSSL_FAILURE +*/ +int wolfSSL_SESSION_set_cipher(WOLFSSL_SESSION* session, + const WOLFSSL_CIPHER* cipher) +{ + WOLFSSL_ENTER("wolfSSL_SESSION_set_cipher"); + + /* sanity check */ + if (session == NULL || cipher == NULL) { + WOLFSSL_MSG("bad argument"); + return WOLFSSL_FAILURE; + } + session->cipherSuite0 = cipher->cipherSuite0; + session->cipherSuite = cipher->cipherSuite; + + WOLFSSL_LEAVE("wolfSSL_SESSION_set_cipher", WOLFSSL_SUCCESS); + return WOLFSSL_SUCCESS; +} #endif /* OPENSSL_EXTRA || HAVE_EXT_CACHE */ @@ -55912,21 +55952,7 @@ int wolfSSL_CTX_get_security_level(const WOLFSSL_CTX* ctx) return 0; } -#ifndef NO_WOLFSSL_STUB -/** - * set call back function for psk session use - * @param ssl a pointer to WOLFSSL structure - * @return none - */ -void wolfSSL_set_psk_use_session_callback(WOLFSSL* ssl, - wc_psk_use_session_cb_func cb) -{ - WOLFSSL_STUB("wolfSSL_set_psk_use_session_callback"); - (void)ssl; - (void)cb; -} -#endif /* NO_WOLFSSL_STUB */ /** * Determine whether a WOLFSSL_SESSION object can be used for resumption * @param s a pointer to WOLFSSL_SESSION structure diff --git a/src/tls13.c b/src/tls13.c index fd686a747..5a49a100d 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -2607,6 +2607,49 @@ static byte helloRetryRequestRandom[] = { #ifndef NO_WOLFSSL_CLIENT #if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK) +#if defined(OPENSSL_EXTRA) && !defined(WOLFSSL_PSK_ONE_ID) && \ + !defined(NO_PSK) +/** +* convert mac algorithm to WOLFSSL_EVP_MD +* @param mac_alg mac algorithm +* @return const WOLFSSL_EVP_MD on sucessful, otherwise NULL +*/ +static const WOLFSSL_EVP_MD* ssl_handshake_md(const byte mac_alg) +{ + switch(mac_alg) { + case no_mac: + #ifndef NO_MD5 + case md5_mac: + return wolfSSL_EVP_md5(); + #endif + #ifndef NO_SHA + case sha_mac: + return wolfSSL_EVP_sha1(); + #endif + #ifdef WOLFSSL_SHA224 + case sha224_mac: + return wolfSSL_EVP_sha224(); + #endif + case sha256_mac: + return wolfSSL_EVP_sha256(); + #ifdef WOLFSSL_SHA384 + case sha384_mac: + return wolfSSL_EVP_sha384(); + #endif + #ifdef WOLFSSL_SHA512 + case sha512_mac: + return wolfSSL_EVP_sha512(); + #endif + case rmd_mac: + case blake2b_mac: + WOLFSSL_MSG("no suitable EVP_MD"); + return NULL; + default: + WOLFSSL_MSG("Unknown mac algorithm"); + return NULL; + } +} +#endif /* Setup pre-shared key based on the details in the extension data. * * ssl SSL/TLS object. @@ -2652,9 +2695,33 @@ static int SetupPskKey(WOLFSSL* ssl, PreSharedKey* psk) const char* cipherName = NULL; byte cipherSuite0 = TLS13_BYTE, cipherSuite = WOLFSSL_DEF_PSK_CIPHER; int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE; + + #ifdef OPENSSL_EXTRA + const unsigned char* id = NULL; + size_t idlen = 0; + WOLFSSL_SESSION* psksession = NULL; + const WOLFSSL_EVP_MD* handshake_md = NULL; + if (ssl->options.session_psk_cb != NULL) { + + if (ssl->msgsReceived.got_hello_retry_request >= 1) { + handshake_md = ssl_handshake_md(ssl->specs.mac_algorithm); + } + /* Get the pre-shared key. */ + if (!ssl->options.session_psk_cb(ssl, handshake_md, + &id, &idlen, &psksession)) { + wolfSSL_SESSION_free(psksession); + WOLFSSL_MSG("psk session callback failed"); + return PSK_KEY_ERROR; + } + } + + if (psksession == NULL && + #else /* Get the pre-shared key. */ - if (ssl->options.client_psk_tls13_cb != NULL) { + if ( + #endif + ssl->options.client_psk_tls13_cb != NULL) { ssl->arrays->psk_keySz = ssl->options.client_psk_tls13_cb(ssl, (char *)psk->identity, ssl->arrays->client_identity, MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN, @@ -2665,6 +2732,22 @@ static int SetupPskKey(WOLFSSL* ssl, PreSharedKey* psk) } } else { + #ifdef OPENSSL_EXTRA + if (psksession != NULL) { + if (idlen > MAX_PSK_KEY_LEN) { + WOLFSSL_MSG("psk key length is too long"); + return PSK_KEY_ERROR; + } + + ssl->arrays->psk_keySz = (word32)idlen; + XMEMCPY(ssl->arrays->psk_key, id, idlen); + cipherSuite0 = psksession->cipherSuite0; + cipherSuite = psksession->cipherSuite; + /* no need anymore */ + wolfSSL_SESSION_free(psksession); + } + else + #endif ssl->arrays->psk_keySz = ssl->options.client_psk_cb(ssl, (char *)psk->identity, ssl->arrays->client_identity, MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN); diff --git a/tests/api.c b/tests/api.c index be6c3091c..034c07a48 100644 --- a/tests/api.c +++ b/tests/api.c @@ -42918,7 +42918,6 @@ static void test_wolfSSL_set_psk_use_session_callback(void) ssl = SSL_new(ctx); AssertNotNull(ssl); - /* STUB */ SSL_set_psk_use_session_callback(ssl, my_psk_use_session_cb); AssertTrue(1); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 9ceb7a1fd..7a8a87765 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3506,6 +3506,9 @@ typedef struct Options { #ifndef NO_PSK wc_psk_client_callback client_psk_cb; wc_psk_server_callback server_psk_cb; +#ifdef OPENSSL_EXTRA + wc_psk_use_session_cb_func session_psk_cb; +#endif #ifdef WOLFSSL_TLS13 wc_psk_client_tls13_callback client_psk_tls13_cb; /* client callback */ wc_psk_server_tls13_callback server_psk_tls13_cb; /* server callback */ diff --git a/wolfssl/openssl/ssl.h b/wolfssl/openssl/ssl.h index 850d3a230..7e8415706 100644 --- a/wolfssl/openssl/ssl.h +++ b/wolfssl/openssl/ssl.h @@ -333,6 +333,7 @@ typedef STACK_OF(ACCESS_DESCRIPTION) AUTHORITY_INFO_ACCESS; #define SSL_SESSION_up_ref wolfSSL_SESSION_up_ref #define SSL_SESSION_dup wolfSSL_SESSION_dup #define SSL_SESSION_free wolfSSL_SESSION_free +#define SSL_SESSION_set_cipher wolfSSL_SESSION_set_cipher #define SSL_is_init_finished wolfSSL_is_init_finished #define SSL_get_version wolfSSL_get_version diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index 12677344b..5e8e1f7f0 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -1332,6 +1332,8 @@ WOLFSSL_API int wolfSSL_SESSION_up_ref(WOLFSSL_SESSION* session); WOLFSSL_API WOLFSSL_SESSION* wolfSSL_SESSION_dup(WOLFSSL_SESSION* session); WOLFSSL_API WOLFSSL_SESSION* wolfSSL_SESSION_new(void); WOLFSSL_API void wolfSSL_SESSION_free(WOLFSSL_SESSION* session); +WOLFSSL_API int wolfSSL_SESSION_set_cipher(WOLFSSL_SESSION* session, + const WOLFSSL_CIPHER* cipher); WOLFSSL_API int wolfSSL_is_init_finished(WOLFSSL*); WOLFSSL_API const char* wolfSSL_get_version(const WOLFSSL*); @@ -2179,6 +2181,13 @@ enum { /* ssl Constants */ wc_psk_client_callback); WOLFSSL_API void wolfSSL_set_psk_client_callback(WOLFSSL*, wc_psk_client_callback); + #ifdef OPENSSL_EXTRA + typedef int (*wc_psk_use_session_cb_func)(WOLFSSL* ssl, + const WOLFSSL_EVP_MD* md, const unsigned char **id, + size_t* idlen, WOLFSSL_SESSION **sess); + WOLFSSL_API void wolfSSL_set_psk_use_session_callback(WOLFSSL* ssl, + wc_psk_use_session_cb_func cb); + #endif #ifdef WOLFSSL_TLS13 typedef unsigned int (*wc_psk_client_tls13_callback)(WOLFSSL*, const char*, char*, unsigned int, unsigned char*, unsigned int, const char**); @@ -4396,13 +4405,6 @@ WOLFSSL_API int wolfSSL_EVP_PKEY_param_check(WOLFSSL_EVP_PKEY_CTX* ctx); WOLFSSL_API void wolfSSL_CTX_set_security_level(WOLFSSL_CTX* ctx, int level); WOLFSSL_API int wolfSSL_CTX_get_security_level(const WOLFSSL_CTX* ctx); -typedef int (*wc_psk_use_session_cb_func)(WOLFSSL* ssl, const WOLFSSL_EVP_MD* md, - const unsigned char **id, - size_t* idlen, - WOLFSSL_SESSION **sess); -WOLFSSL_API void wolfSSL_set_psk_use_session_callback(WOLFSSL* ssl, - wc_psk_use_session_cb_func cb); - WOLFSSL_API int wolfSSL_SESSION_is_resumable(const WOLFSSL_SESSION *s); WOLFSSL_API void wolfSSL_CRYPTO_free(void *str, const char *file, int line); diff --git a/wolfssl/test.h b/wolfssl/test.h index 1d4cb9cf2..8532125d7 100644 --- a/wolfssl/test.h +++ b/wolfssl/test.h @@ -1515,19 +1515,77 @@ static WC_INLINE unsigned int my_psk_server_tls13_cb(WOLFSSL* ssl, return 32; /* length of key in octets or 0 for error */ } +#if defined(OPENSSL_ALL) && !defined(NO_CERTS) && \ + !defined(NO_FILESYSTEM) +static unsigned char local_psk[32]; +#endif static WC_INLINE int my_psk_use_session_cb(WOLFSSL* ssl, const WOLFSSL_EVP_MD* md, const unsigned char **id, size_t* idlen, WOLFSSL_SESSION **sess) { +#if defined(OPENSSL_ALL) && !defined(NO_CERTS) && \ + !defined(NO_FILESYSTEM) + int i; + int b = 0x01; + WOLFSSL_SESSION* lsess; + char buf[256]; + const char* cipher_id = "TLS13-AES128-GCM-SHA256"; + const SSL_CIPHER* cipher = NULL; + STACK_OF(SSL_CIPHER) *supportedCiphers = NULL; + int numCiphers = 0; + (void)ssl; + (void)md; + + printf("use psk session callback \n"); + + lsess = wolfSSL_SESSION_new(); + if (lsess == NULL) { + return 0; + } + supportedCiphers = SSL_get_ciphers(ssl); + numCiphers = sk_num(supportedCiphers); + + for (i = 0; i < numCiphers; ++i) { + + if ((cipher = (const WOLFSSL_CIPHER*)sk_value(supportedCiphers, i))) { + SSL_CIPHER_description(cipher, buf, sizeof(buf)); + } + + if (XMEMCMP(cipher_id, buf, XSTRLEN(cipher_id)) == 0) { + break; + } + } + + if (i != numCiphers) { + SSL_SESSION_set_cipher(lsess, cipher); + for (i = 0; i < 32; i++, b += 0x22) { + if (b >= 0x100) + b = 0x01; + local_psk[i] = b; + } + + *id = local_psk; + *idlen = 32; + *sess = lsess; + + return 1; + } + else { + *id = NULL; + *idlen = 0; + *sess = NULL; + return 0; + } +#else (void)ssl; (void)md; (void)id; (void)idlen; (void)sess; - + return 0; +#endif } - #endif /* !NO_PSK */