diff --git a/sslSniffer/sslSnifferTest/snifftest.c b/sslSniffer/sslSnifferTest/snifftest.c index cebbb2682..af6bbe56d 100644 --- a/sslSniffer/sslSnifferTest/snifftest.c +++ b/sslSniffer/sslSnifferTest/snifftest.c @@ -375,8 +375,9 @@ int main(int argc, char** argv) int frame = ETHER_IF_FRAME_LEN; char err[PCAP_ERRBUF_SIZE]; char filter[32]; - const char *keyFiles = NULL; + const char *keyFilesSrc = NULL; char keyFilesBuf[MAX_FILENAME_SZ]; + char keyFilesUser[MAX_FILENAME_SZ]; const char *server = NULL; const char *sniName = NULL; struct bpf_program fp; @@ -493,21 +494,23 @@ int main(int argc, char** argv) /* optionally enter the private key to use */ #if defined(WOLFSSL_STATIC_EPHEMERAL) && defined(DEFAULT_SERVER_EPH_KEY) - keyFiles = DEFAULT_SERVER_EPH_KEY; + keyFilesSrc = DEFAULT_SERVER_EPH_KEY; #else - keyFiles = DEFAULT_SERVER_KEY; + keyFilesSrc = DEFAULT_SERVER_KEY; #endif - printf("Enter the server key [default: %s]: ", keyFiles); + printf("Enter the server key [default: %s]: ", keyFilesSrc); XMEMSET(keyFilesBuf, 0, sizeof(keyFilesBuf)); - if (XFGETS(keyFilesBuf, sizeof(keyFilesBuf), stdin)) { - if (keyFilesBuf[0] != '\r' && keyFilesBuf[0] != '\n') { - keyFiles = keyFilesBuf; + XMEMSET(keyFilesUser, 0, sizeof(keyFilesUser)); + if (XFGETS(keyFilesUser, sizeof(keyFilesUser), stdin)) { + word32 strSz; + if (keyFilesUser[0] != '\r' && keyFilesUser[0] != '\n') { + keyFilesSrc = keyFilesUser; } + strSz = (word32)XSTRLEN(keyFilesUser); + if (keyFilesUser[strSz-1] == '\n') + keyFilesUser[strSz-1] = '\0'; } - if (keyFiles != keyFilesBuf) { - XSTRNCPY(keyFilesBuf, keyFiles, sizeof(keyFilesBuf)); - keyFiles = keyFilesBuf; - } + XSTRNCPY(keyFilesBuf, keyFilesSrc, sizeof(keyFilesBuf)); /* optionally enter a named key (SNI) */ #if !defined(WOLFSSL_SNIFFER_WATCH) && defined(HAVE_SNI) @@ -533,7 +536,8 @@ int main(int argc, char** argv) } if (server) { - ret = load_key(sniName, server, port, keyFiles, NULL, err); + XSTRNCPY(keyFilesBuf, keyFilesSrc, sizeof(keyFilesBuf)); + ret = load_key(sniName, server, port, keyFilesBuf, NULL, err); if (ret != 0) { exit(EXIT_FAILURE); } @@ -553,7 +557,7 @@ int main(int argc, char** argv) /* defaults for server and port */ port = 443; server = "127.0.0.1"; - keyFiles = argv[2]; + keyFilesSrc = argv[2]; if (argc >= 4) server = argv[3]; @@ -564,7 +568,7 @@ int main(int argc, char** argv) if (argc >= 6) passwd = argv[5]; - ret = load_key(NULL, server, port, keyFiles, passwd, err); + ret = load_key(NULL, server, port, keyFilesSrc, passwd, err); if (ret != 0) { exit(EXIT_FAILURE); }