diff --git a/src/sniffer.c b/src/sniffer.c index fbab33bf1..227355c6c 100644 --- a/src/sniffer.c +++ b/src/sniffer.c @@ -248,7 +248,12 @@ static const char* const msgTable[] = "Extended Master Secret Hash Error", "Handshake Message Split Across TLS Records", "ECC Private Decode Error", - "ECC Public Decode Error" + "ECC Public Decode Error", + + /* 86 */ + "Watch callback not set", + "Watch hash failed", + "Watch callback failed" }; @@ -416,6 +421,12 @@ static SSLStats SnifferStats; static wolfSSL_Mutex StatsMutex; #endif +#ifdef WOLFSSL_SNIFFER_WATCH +/* Watch Key Callback */ +static SSLWatchCb WatchCb; +static void* WatchCbCtx = NULL; +#endif + static void UpdateMissedDataSessions(void) { @@ -1097,6 +1108,8 @@ static void SetError(int idx, char* error, SnifferSession* session, int fatal) } +#ifndef WOLFSSL_SNIFFER_WATCH + /* See if this IPV4 network order address has been registered */ /* return 1 is true, 0 is false */ static int IsServerRegistered(word32 addr) @@ -1144,6 +1157,8 @@ static int IsPortRegistered(word32 port) return ret; } +#endif + /* Get SnifferServer from IP and Port */ static SnifferServer* GetSnifferServer(IpInfo* ipInfo, TcpInfo* tcpInfo) @@ -1153,6 +1168,8 @@ static SnifferServer* GetSnifferServer(IpInfo* ipInfo, TcpInfo* tcpInfo) wc_LockMutex(&ServerListMutex); sniffer = ServerList; + +#ifndef WOLFSSL_SNIFFER_WATCH while (sniffer) { if (sniffer->port == tcpInfo->srcPort && sniffer->server == ipInfo->src) break; @@ -1160,6 +1177,10 @@ static SnifferServer* GetSnifferServer(IpInfo* ipInfo, TcpInfo* tcpInfo) break; sniffer = sniffer->next; } +#else + (void)ipInfo; + (void)tcpInfo; +#endif wc_UnLockMutex(&ServerListMutex); @@ -1209,8 +1230,8 @@ static SnifferSession* GetSnifferSession(IpInfo* ipInfo, TcpInfo* tcpInfo) /* determine side */ if (session) { - if (ipInfo->dst == session->context->server && - tcpInfo->dstPort == session->context->port) + if (ipInfo->dst == session->server && + tcpInfo->dstPort == session->srvPort) session->flags.side = WOLFSSL_SERVER_END; else session->flags.side = WOLFSSL_CLIENT_END; @@ -1220,7 +1241,7 @@ static SnifferSession* GetSnifferSession(IpInfo* ipInfo, TcpInfo* tcpInfo) } -#ifdef HAVE_SNI +#if defined(HAVE_SNI) || defined(WOLFSSL_SNIFFER_WATCH) static int LoadKeyFile(byte** keyBuf, word32* keyBufSz, const char* keyFile, int typeKey, @@ -1295,6 +1316,32 @@ static int LoadKeyFile(byte** keyBuf, word32* keyBufSz, #endif +#ifdef WOLFSSL_SNIFFER_WATCH + +static int CreateWatchSnifferServer(char* error) +{ + SnifferServer* sniffer; + + sniffer = (SnifferServer*)malloc(sizeof(SnifferServer)); + if (sniffer == NULL) { + SetError(MEMORY_STR, error, NULL, 0); + return -1; + } + InitSnifferServer(sniffer); + sniffer->ctx = SSL_CTX_new(TLSv1_2_client_method()); + if (!sniffer->ctx) { + SetError(MEMORY_STR, error, NULL, 0); + FreeSnifferServer(sniffer); + return -1; + } + ServerList = sniffer; + + return 0; +} + +#endif + + static int SetNamedPrivateKey(const char* name, const char* address, int port, const char* keyFile, int typeKey, const char* password, char* error) { @@ -1473,10 +1520,12 @@ static int CheckIpHdr(IpHdr* iphdr, IpInfo* info, int length, char* error) return -1; } +#ifndef WOLFSSL_SNIFFER_WATCH if (!IsServerRegistered(iphdr->src) && !IsServerRegistered(iphdr->dst)) { SetError(SERVER_NOT_REG_STR, error, NULL, 0); return -1; } +#endif info->length = IP_HL(iphdr); info->total = ntohs(iphdr->length); @@ -1507,10 +1556,14 @@ static int CheckTcpHdr(TcpHdr* tcphdr, TcpInfo* info, char* error) if (info->ack) info->ackNumber = ntohl(tcphdr->ack); +#ifndef WOLFSSL_SNIFFER_WATCH if (!IsPortRegistered(info->srcPort) && !IsPortRegistered(info->dstPort)) { SetError(SERVER_PORT_NOT_REG_STR, error, NULL, 0); return -1; } +#else + (void)error; +#endif return 0; } @@ -2254,6 +2307,56 @@ static int ProcessClientHello(const byte* input, int* sslBytes, } +#ifdef WOLFSSL_SNIFFER_WATCH + +/* Process Certificate */ +static int ProcessCertificate(const byte* input, int* sslBytes, + SnifferSession* session, char* error) +{ + Sha256 sha; + word32 certSz; + int ret; + byte digest[SHA256_DIGEST_SIZE]; + + (void)sslBytes; + + /* If the receiver is the server, this is the client certificate message, + * and it should be ignored at this point. */ + if (session->flags.side == WOLFSSL_SERVER_END) + return 0; + + if (WatchCb == NULL) { + SetError(WATCH_CB_MISSING_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + input += CERT_HEADER_SZ; + ato24(input, &certSz); + input += OPAQUE24_LEN; + + ret = wc_InitSha256(&sha); + if (ret == 0) + ret = wc_Sha256Update(&sha, input, certSz); + if (ret == 0) + ret = wc_Sha256Final(&sha, digest); + if (ret != 0) { + SetError(WATCH_HASH_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + ret = WatchCb((void*)session, digest, sizeof(digest), input, certSz, + WatchCbCtx, error); + if (ret != 0) { + SetError(WATCH_FAIL_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + return 0; +} + +#endif + + /* Process Finished */ static int ProcessFinished(const byte* input, int size, int* sslBytes, SnifferSession* session, char* error) @@ -2374,6 +2477,9 @@ static int DoHandShake(const byte* input, int* sslBytes, INC_STAT(SnifferStats.sslClientAuthConns); #endif } +#ifdef WOLFSSL_SNIFFER_WATCH + ret = ProcessCertificate(input, sslBytes, session, error); +#endif break; case server_hello_done: Trace(GOT_SERVER_HELLO_DONE_STR); @@ -2708,12 +2814,10 @@ static SnifferSession* CreateSession(IpInfo* ipInfo, TcpInfo* tcpInfo, wc_UnLockMutex(&SessionMutex); - /* determine headed side */ - if (ipInfo->dst == session->context->server && - tcpInfo->dstPort == session->context->port) - session->flags.side = WOLFSSL_SERVER_END; - else - session->flags.side = WOLFSSL_CLIENT_END; + /* CreateSession is called in response to a SYN packet, we know this + * is headed to the server. Also we know the server is one we care + * about as we've passed the GetSnifferServer() successfully. */ + session->flags.side = WOLFSSL_SERVER_END; return session; } @@ -4019,5 +4123,63 @@ int ssl_ReadResetStatistics(SSLStats* stats) #endif /* WOLFSSL_SNIFFER_STATS */ +#ifdef WOLFSSL_SNIFFER_WATCH + +int ssl_SetWatchKeyCallback(SSLWatchCb cb, char* error) +{ + WatchCb = cb; + return CreateWatchSnifferServer(error); +} + + +int ssl_SetWatchKeyCtx(void* ctx, char* error) +{ + (void)error; + WatchCbCtx = ctx; + return 0; +} + + +int ssl_SetWatchKey(void* vSniffer, const char* keyFile, int keyType, + const char* password, char* error) +{ + SnifferSession* sniffer; + byte* keyBuf = NULL; + word32 keyBufSz = 0; + int ret; + + if (vSniffer == NULL) { + return -1; + } + if (keyFile == NULL) { + return -1; + } + + sniffer = (SnifferSession*)vSniffer; + /* Remap the keyType from what the user can use to + * what LoadKeyFile expects. */ + keyType = (keyType == FILETYPE_PEM) ? WOLFSSL_FILETYPE_PEM : + WOLFSSL_FILETYPE_ASN1; + + ret = LoadKeyFile(&keyBuf, &keyBufSz, keyFile, keyType, password); + if (ret < 0) { + SetError(KEY_FILE_STR, error, NULL, 0); + free(keyBuf); + return -1; + } + + ret = wolfSSL_use_PrivateKey_buffer(sniffer->sslServer, + keyBuf, keyBufSz, WOLFSSL_FILETYPE_ASN1); + if (ret != WOLFSSL_SUCCESS) { + SetError(KEY_FILE_STR, error, sniffer, FATAL_ERROR_STATE); + free(keyBuf); + return -1; + } + + return 0; +} + +#endif /* WOLFSSL_SNIFFER_WATCH */ + #endif /* WOLFSSL_SNIFFER */ #endif /* WOLFCRYPT_ONLY */ diff --git a/sslSniffer/sslSnifferTest/snifftest.c b/sslSniffer/sslSnifferTest/snifftest.c index 0bdf4c718..46acf836f 100644 --- a/sslSniffer/sslSnifferTest/snifftest.c +++ b/sslSniffer/sslSnifferTest/snifftest.c @@ -170,6 +170,27 @@ static char* iptos(unsigned int addr) } +#ifdef WOLFSSL_SNIFFER_WATCH + +static int myWatchCb(void* vSniffer, + const unsigned char* certHash, unsigned int certHashSz, + const unsigned char* cert, unsigned int certSz, + void* ctx, char* error) +{ + (void)certHash; + (void)certHashSz; + (void)cert; + (void)certSz; + (void)ctx; + + return ssl_SetWatchKey(vSniffer, + "../../certs/server-key.pem", + FILETYPE_PEM, NULL, error); +} + +#endif + + int main(int argc, char** argv) { int ret = 0; @@ -193,6 +214,9 @@ int main(int argc, char** argv) #endif ssl_Trace("./tracefile.txt", err); ssl_EnableRecovery(1, -1, err); +#ifdef WOLFSSL_SNIFFER_WATCH + ssl_SetWatchKeyCallback(myWatchCb, err); +#endif if (argc == 1) { /* normal case, user chooses device and port */ @@ -275,6 +299,7 @@ int main(int argc, char** argv) ret = pcap_setfilter(pcap, &fp); if (ret != 0) printf("pcap_setfilter failed %s\n", pcap_geterr(pcap)); +#ifndef WOLFSSL_SNIFFER_WATCH ret = ssl_SetPrivateKey(server, port, "../../certs/server-key.pem", FILETYPE_PEM, NULL, err); if (ret != 0) { @@ -298,6 +323,7 @@ int main(int argc, char** argv) } } } +#endif #endif } else if (argc >= 3) { diff --git a/wolfssl/sniffer.h b/wolfssl/sniffer.h index 7272efce3..090fdd9be 100644 --- a/wolfssl/sniffer.h +++ b/wolfssl/sniffer.h @@ -167,6 +167,23 @@ WOLFSSL_API SSL_SNIFFER_API int ssl_ReadResetStatistics(SSLStats* stats); +typedef int (*SSLWatchCb)(void* vSniffer, + const unsigned char* certHash, unsigned int certHashSz, + const unsigned char* cert, unsigned int certSz, + void* ctx, char* error); + +WOLFSSL_API +SSL_SNIFFER_API int ssl_SetWatchKeyCallback(SSLWatchCb cb, char* error); + +WOLFSSL_API +SSL_SNIFFER_API int ssl_SetWatchKeyCtx(void* ctx, char* error); + +WOLFSSL_API +SSL_SNIFFER_API int ssl_SetWatchKey(void* vSniffer, + const char* keyFile, int keyType, + const char* password, char* error); + + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/wolfssl/sniffer_error.h b/wolfssl/sniffer_error.h index 68ebf321a..844f278a0 100644 --- a/wolfssl/sniffer_error.h +++ b/wolfssl/sniffer_error.h @@ -121,6 +121,9 @@ #define SPLIT_HANDSHAKE_MSG_STR 83 #define ECC_DECODE_STR 84 #define ECC_PUB_DECODE_STR 85 +#define WATCH_CB_MISSING_STR 86 +#define WATCH_HASH_STR 87 +#define WATCH_FAIL_STR 88 /* !!!! also add to msgTable in sniffer.c and .rc file !!!! */ diff --git a/wolfssl/sniffer_error.rc b/wolfssl/sniffer_error.rc index e133ae06e..58fb365e4 100644 --- a/wolfssl/sniffer_error.rc +++ b/wolfssl/sniffer_error.rc @@ -102,5 +102,9 @@ STRINGTABLE 83, "Handshake Message Split Across TLS Records" 84, "ECC Private Decode Error" 85, "ECC Public Decode Error" + + 86, "Watch callback not set" + 87, "Watch hash failed" + 88, "Watch callback failed" }