diff --git a/cyassl/internal.h b/cyassl/internal.h index 6c386d71f..e7ca9dddc 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -884,18 +884,18 @@ int SetCipherList(Suites*, const char* list); CYASSL_LOCAL void EmbedOcspRespFree(void*, byte*); #endif -#endif -#ifdef CYASSL_DTLS - CYASSL_LOCAL - int EmbedReceiveFrom(CYASSL *ssl, char *buf, int sz, void *ctx); - CYASSL_LOCAL - int EmbedSendTo(CYASSL *ssl, char *buf, int sz, void *ctx); - CYASSL_LOCAL - int EmbedGenerateCookie(byte *buf, int sz, void *ctx); - CYASSL_LOCAL - int IsUDP(void*); -#endif + #ifdef CYASSL_DTLS + CYASSL_LOCAL + int EmbedReceiveFrom(CYASSL *ssl, char *buf, int sz, void *ctx); + CYASSL_LOCAL + int EmbedSendTo(CYASSL *ssl, char *buf, int sz, void *ctx); + CYASSL_LOCAL + int EmbedGenerateCookie(CYASSL* ssl, byte *buf, int sz, void *ctx); + CYASSL_LOCAL + int IsUDP(void*); + #endif /* CYASSL_DTLS */ +#endif /* CYASSL_USER_IO */ /* CyaSSL Cipher type just points back to SSL */ @@ -1089,6 +1089,9 @@ struct CYASSL_CTX { byte groupMessages; /* group handshake messages before sending */ CallbackIORecv CBIORecv; CallbackIOSend CBIOSend; +#ifdef CYASSL_DTLS + CallbackGenCookie CBIOCookie; /* gen cookie callback */ +#endif VerifyCallback verifyCallback; /* cert verification callback */ word32 timeout; /* session timeout */ #ifdef HAVE_ECC @@ -1623,6 +1626,7 @@ struct CYASSL { int dtls_timeout; DtlsPool* dtls_pool; DtlsMsg* dtls_msg_list; + void* IOCB_CookieCtx; /* gen cookie ctx */ #endif #ifdef CYASSL_CALLBACKS HandShakeInfo handShakeInfo; /* info saved during handshake */ diff --git a/cyassl/ssl.h b/cyassl/ssl.h index cef60f5f3..8f0eafc85 100644 --- a/cyassl/ssl.h +++ b/cyassl/ssl.h @@ -809,6 +809,11 @@ CYASSL_API void CyaSSL_SetIOWriteCtx(CYASSL* ssl, void *ctx); CYASSL_API void CyaSSL_SetIOReadFlags( CYASSL* ssl, int flags); CYASSL_API void CyaSSL_SetIOWriteFlags(CYASSL* ssl, int flags); +typedef int (*CallbackGenCookie)(CYASSL* ssl, unsigned char* buf, int sz, + void* ctx); +CYASSL_API void CyaSSL_CTX_SetGenCookie(CYASSL_CTX*, CallbackGenCookie); +CYASSL_API void CyaSSL_SetCookieCtx(CYASSL* ssl, void *ctx); + typedef int (*CallbackIOOcsp)(void*, const char*, int, unsigned char*, int, unsigned char**); typedef void (*CallbackIOOcspRespFree)(void*,unsigned char*); diff --git a/src/internal.c b/src/internal.c index a61148285..b7596b8ac 100644 --- a/src/internal.c +++ b/src/internal.c @@ -367,14 +367,16 @@ int InitSSL_Ctx(CYASSL_CTX* ctx, CYASSL_METHOD* method) #ifdef CYASSL_DTLS if (method->version.major == DTLS_MAJOR && method->version.minor >= DTLSv1_2_MINOR) { - ctx->CBIORecv = EmbedReceiveFrom; - ctx->CBIOSend = EmbedSendTo; + ctx->CBIORecv = EmbedReceiveFrom; + ctx->CBIOSend = EmbedSendTo; + ctx->CBIOCookie = EmbedGenerateCookie; } #endif #else /* user will set */ - ctx->CBIORecv = NULL; - ctx->CBIOSend = NULL; + ctx->CBIORecv = NULL; + ctx->CBIOSend = NULL; + ctx->CBIOCookie = NULL; #endif ctx->partialWrite = 0; ctx->verifyCallback = 0; @@ -1227,6 +1229,9 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) ssl->IOCB_ReadCtx = &ssl->rfd; /* prevent invalid pointer access if not */ ssl->IOCB_WriteCtx = &ssl->wfd; /* correctly set */ +#ifdef CYASSL_DTLS + ssl->IOCB_CookieCtx = NULL; /* we don't use for default cb */ +#endif #ifndef NO_OLD_TLS #ifndef NO_MD5 @@ -2196,6 +2201,11 @@ static int Receive(CYASSL* ssl, byte* buf, word32 sz) { int recvd; + if (ssl->ctx->CBIORecv == NULL) { + CYASSL_MSG("Your IO Recv callback is null, please set"); + return -1; + } + retry: recvd = ssl->ctx->CBIORecv(ssl, (char *)buf, (int)sz, ssl->IOCB_ReadCtx); if (recvd < 0) @@ -2290,6 +2300,11 @@ void ShrinkInputBuffer(CYASSL* ssl, int forcedFree) int SendBuffered(CYASSL* ssl) { + if (ssl->ctx->CBIOSend == NULL) { + CYASSL_MSG("Your IO Send callback is null, please set"); + return SOCKET_ERROR_E; + } + while (ssl->buffers.outputBuffer.length > 0) { int sent = ssl->ctx->CBIOSend(ssl, (char*)ssl->buffers.outputBuffer.buffer + @@ -9016,8 +9031,12 @@ int SetCipherList(Suites* s, const char* list) return BUFFER_ERROR; if (i + b > totalSz) return INCOMPLETE_DATA; - if ((EmbedGenerateCookie(cookie, COOKIE_SZ, ssl) - != COOKIE_SZ) + if (ssl->ctx->CBIORecv == NULL) { + CYASSL_MSG("Your Cookie callback is null, please set"); + return COOKIE_ERROR; + } + if ((ssl->ctx->CBIOCookie(ssl, cookie, COOKIE_SZ, + ssl->IOCB_CookieCtx) != COOKIE_SZ) || (b != COOKIE_SZ) || (XMEMCMP(cookie, input + i, b) != 0)) { return COOKIE_ERROR; @@ -9327,7 +9346,12 @@ int SetCipherList(Suites* s, const char* list) output[idx++] = ssl->chVersion.minor; output[idx++] = cookieSz; - if ((ret = EmbedGenerateCookie(output + idx, cookieSz, ssl)) < 0) + if (ssl->ctx->CBIORecv == NULL) { + CYASSL_MSG("Your Cookie callback is null, please set"); + return COOKIE_ERROR; + } + if ((ret = ssl->ctx->CBIOCookie(ssl, output + idx, cookieSz, + ssl->IOCB_CookieCtx)) < 0) return ret; HashOutput(ssl, output, sendSz, 0); diff --git a/src/io.c b/src/io.c index 6edb8008d..9e4a763c9 100644 --- a/src/io.c +++ b/src/io.c @@ -429,9 +429,8 @@ int EmbedSendTo(CYASSL* ssl, char *buf, int sz, void *ctx) /* The DTLS Generate Cookie callback * return : number of bytes copied into buf, or error */ -int EmbedGenerateCookie(byte *buf, int sz, void *ctx) +int EmbedGenerateCookie(CYASSL* ssl, byte *buf, int sz, void *ctx) { - CYASSL* ssl = (CYASSL*)ctx; int sd = ssl->wfd; struct sockaddr_in peer; XSOCKLENT peerSz = sizeof(peer); @@ -439,6 +438,8 @@ int EmbedGenerateCookie(byte *buf, int sz, void *ctx) int cookieSrcSz = 0; Sha sha; + (void)ctx; + if (getpeername(sd, (struct sockaddr*)&peer, &peerSz) != 0) { CYASSL_MSG("getpeername failed in EmbedGenerateCookie"); return GEN_COOKIE_E; @@ -783,6 +784,23 @@ CYASSL_API void CyaSSL_SetIOWriteFlags(CYASSL* ssl, int flags) ssl->wflags = flags; } + +#ifdef CYASSL_DTLS + +CYASSL_API void CyaSSL_CTX_SetGenCookie(CYASSL_CTX* ctx, CallbackGenCookie cb) +{ + ctx->CBIOCookie = cb; +} + + +CYASSL_API void CyaSSL_SetCookieCtx(CYASSL* ssl, void *ctx) +{ + ssl->IOCB_CookieCtx = ctx; +} + +#endif /* CYASSL_DTLS */ + + #ifdef HAVE_OCSP CYASSL_API void CyaSSL_SetIOOcsp(CYASSL_CTX* ctx, CallbackIOOcsp cb)