diff --git a/src/internal.c b/src/internal.c index 71825664d..2fdb3864a 100755 --- a/src/internal.c +++ b/src/internal.c @@ -4289,6 +4289,8 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) #ifdef WOLFSSL_MULTICAST if (ctx->haveMcast) { + int i; + ssl->options.haveMcast = 1; ssl->options.mcastID = ctx->mcastID; @@ -4300,6 +4302,9 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) ssl->options.acceptState = ACCEPT_THIRD_REPLY_DONE; ssl->options.handShakeState = HANDSHAKE_DONE; ssl->options.handShakeDone = 1; + + for (i = 0; i < WOLFSSL_DTLS_PEERSEQ_SZ; i++) + ssl->keys.peerSeq[i].peerId = INVALID_PEER_ID; } #endif @@ -9609,12 +9614,31 @@ static INLINE int DtlsCheckWindow(WOLFSSL* ssl) word16 cur_hi, next_hi; word32 cur_lo, next_lo, diff; int curLT; - WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; + WOLFSSL_DTLS_PEERSEQ* peerSeq = NULL; + if (!ssl->options.haveMcast) + peerSeq = ssl->keys.peerSeq; + else { #ifdef WOLFSSL_MULTICAST - if (ssl->options.haveMcast) - peerSeq += ssl->keys.curPeerId; + WOLFSSL_DTLS_PEERSEQ* p; + int i; + + for (i = 0, p = ssl->keys.peerSeq; + i < WOLFSSL_DTLS_PEERSEQ_SZ; + i++, p++) { + + if (p->peerId == ssl->keys.curPeerId) { + peerSeq = p; + break; + } + } + + if (peerSeq == NULL) { + WOLFSSL_MSG("Couldn't find that peer ID to check window."); + return 0; + } #endif + } if (ssl->keys.curEpoch == peerSeq->nextEpoch) { next_hi = peerSeq->nextSeq_hi; @@ -9692,10 +9716,30 @@ static INLINE int DtlsUpdateWindow(WOLFSSL* ssl) word16 cur_hi; WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; + if (!ssl->options.haveMcast) + peerSeq = ssl->keys.peerSeq; + else { #ifdef WOLFSSL_MULTICAST - if (ssl->options.haveMcast) - peerSeq += ssl->keys.curPeerId; + WOLFSSL_DTLS_PEERSEQ* p; + int i; + + peerSeq = NULL; + for (i = 0, p = ssl->keys.peerSeq; + i < WOLFSSL_DTLS_PEERSEQ_SZ; + i++, p++) { + + if (p->peerId == ssl->keys.curPeerId) { + peerSeq = p; + break; + } + } + + if (peerSeq == NULL) { + WOLFSSL_MSG("Couldn't find that peer ID to update window."); + return 0; + } #endif + } if (ssl->keys.curEpoch == peerSeq->nextEpoch) { next_hi = &peerSeq->nextSeq_hi; diff --git a/src/ssl.c b/src/ssl.c index 95a5553da..2aa775375 100755 --- a/src/ssl.c +++ b/src/ssl.c @@ -846,13 +846,13 @@ int wolfSSL_dtls_set_mtu(WOLFSSL* ssl, word16 newMtu) #if defined(WOLFSSL_MULTICAST) -int wolfSSL_CTX_mcast_set_member_id(WOLFSSL_CTX* ctx, byte id) +int wolfSSL_CTX_mcast_set_member_id(WOLFSSL_CTX* ctx, word16 id) { int ret = 0; WOLFSSL_ENTER("wolfSSL_CTX_mcast_set_member_id()"); - if (ctx == NULL || id >= WOLFSSL_MULTICAST_PEERS) + if (ctx == NULL || id > 255) ret = BAD_FUNC_ARG; if (ret == 0) { @@ -935,6 +935,56 @@ int wolfSSL_set_secret(WOLFSSL* ssl, unsigned short epoch, return ret; } + +int wolfSSL_mcast_peer_add(WOLFSSL* ssl, word16 peerId, int remove) +{ + WOLFSSL_DTLS_PEERSEQ* p = NULL; + int ret = SSL_SUCCESS; + int i; + + WOLFSSL_ENTER("wolfSSL_mcast_peer_add()"); + if (ssl == NULL || peerId > 255) + return BAD_FUNC_ARG; + + if (!remove) { + /* Make sure it isn't already present, while keeping the first + * open spot. */ + for (i = 0; i < WOLFSSL_DTLS_PEERSEQ_SZ; i++) { + if (ssl->keys.peerSeq[i].peerId == INVALID_PEER_ID) + p = &ssl->keys.peerSeq[i]; + if (ssl->keys.peerSeq[i].peerId == peerId) { + WOLFSSL_MSG("Peer ID already in multicast peer list."); + p = NULL; + } + } + + if (p != NULL) { + XMEMSET(p, 0, sizeof(WOLFSSL_DTLS_PEERSEQ)); + p->peerId = peerId; + } + else { + WOLFSSL_MSG("No room in peer list."); + ret = -1; + } + } + else { + for (i = 0; i < WOLFSSL_DTLS_PEERSEQ_SZ; i++) { + if (ssl->keys.peerSeq[i].peerId == peerId) + p = &ssl->keys.peerSeq[i]; + } + + if (p != NULL) { + p->peerId = INVALID_PEER_ID; + } + else { + WOLFSSL_MSG("Peer not found in list."); + } + } + + WOLFSSL_LEAVE("wolfSSL_mcast_peer_add()", ret); + return ret; +} + #endif /* WOLFSSL_MULTICAST */ @@ -1549,7 +1599,7 @@ int wolfSSL_read(WOLFSSL* ssl, void* data, int sz) #ifdef WOLFSSL_MULTICAST -int wolfSSL_mcast_read(WOLFSSL* ssl, unsigned char* id, void* data, int sz) +int wolfSSL_mcast_read(WOLFSSL* ssl, word16* id, void* data, int sz) { int ret = 0; diff --git a/tests/api.c b/tests/api.c index e3b4a0a5e..94764648e 100644 --- a/tests/api.c +++ b/tests/api.c @@ -2358,7 +2358,7 @@ static void test_wolfSSL_mcast(void) byte serverRandom[32]; byte suite[2] = {0, 0xfe}; /* WDM_WITH_NULL_SHA256 */ byte buf[256]; - byte newId; + word16 newId; ctx = wolfSSL_CTX_new(wolfDTLSv1_2_client_method()); AssertNotNull(ctx); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index daf8b7df6..e14b2ec63 100755 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1168,6 +1168,8 @@ enum Misc { NO_COPY = 0, /* should we copy static buffer for write */ COPY = 1, /* should we copy static buffer for write */ + INVALID_PEER_ID = 0xFFFF, /* Initialize value for peer ID. */ + PREV_ORDER = -1, /* Sequence number is in previous epoch. */ PEER_ORDER = 1, /* Peer sequence number for verify. */ CUR_ORDER = 0 /* Current sequence number. */ @@ -1743,8 +1745,10 @@ typedef struct WOLFSSL_DTLS_PEERSEQ { word32 prevWindow[WOLFSSL_DTLS_WINDOW_WORDS]; /* Sliding window for old epoch */ - word16 prevSeq_hi; /* Next sequence in allowed old epoch */ word32 prevSeq_lo; + word16 prevSeq_hi; /* Next sequence in allowed old epoch */ + + word16 peerId; } WOLFSSL_DTLS_PEERSEQ; @@ -2893,7 +2897,7 @@ typedef struct Options { byte asyncState; /* sub-state for enum asyncState */ byte buildMsgState; /* sub-state for enum buildMsgState */ #ifdef WOLFSSL_MULTICAST - byte mcastID; /* Multicast group ID */ + word16 mcastID; /* Multicast group ID */ #endif #ifndef NO_DH word16 minDhKeySz; /* minimum DH key size */ diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index 831bee2b1..1026129d7 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -509,12 +509,13 @@ WOLFSSL_API int wolfSSL_dtls_set_sctp(WOLFSSL*); WOLFSSL_API int wolfSSL_CTX_dtls_set_mtu(WOLFSSL_CTX*, unsigned short); WOLFSSL_API int wolfSSL_dtls_set_mtu(WOLFSSL*, unsigned short); -WOLFSSL_API int wolfSSL_CTX_mcast_set_member_id(WOLFSSL_CTX*, unsigned char); +WOLFSSL_API int wolfSSL_CTX_mcast_set_member_id(WOLFSSL_CTX*, unsigned short); WOLFSSL_API int wolfSSL_set_secret(WOLFSSL*, unsigned short, const unsigned char*, unsigned int, const unsigned char*, const unsigned char*, const unsigned char*); -WOLFSSL_API int wolfSSL_mcast_read(WOLFSSL*, unsigned char*, void*, int); +WOLFSSL_API int wolfSSL_mcast_read(WOLFSSL*, unsigned short*, void*, int); +WOLFSSL_API int wolfSSL_mcast_peer_add(WOLFSSL*, unsigned short, int); WOLFSSL_API int wolfSSL_ERR_GET_REASON(unsigned long err); WOLFSSL_API char* wolfSSL_ERR_error_string(unsigned long,char*);