diff --git a/ctaocrypt/src/aes.c b/ctaocrypt/src/aes.c index f21e7a1a2..eabd0f4ee 100644 --- a/ctaocrypt/src/aes.c +++ b/ctaocrypt/src/aes.c @@ -1027,6 +1027,13 @@ static void AesEncrypt(Aes* aes, const byte* inBlock, byte* outBlock) CYASSL_MSG("AesEncrypt encountered improper key, set it up"); return; /* stop instead of segfaulting, set up your keys! */ } +#ifdef CYASSL_AESNI + if (haveAESNI) { + CYASSL_MSG("AesEncrypt encountered aesni keysetup, don't use direct"); + return; /* just stop now */ + } +#endif + /* * map byte array block to cipher state * and add initial round key: @@ -1165,6 +1172,13 @@ static void AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) CYASSL_MSG("AesDecrypt encountered improper key, set it up"); return; /* stop instead of segfaulting, set up your keys! */ } +#ifdef CYASSL_AESNI + if (haveAESNI) { + CYASSL_MSG("AesEncrypt encountered aesni keysetup, don't use direct"); + return; /* just stop now */ + } +#endif + /* * map byte array block to cipher state * and add initial round key: @@ -1381,6 +1395,18 @@ void AesDecryptDirect(Aes* aes, byte* out, const byte* in) #endif /* CYASSL_AES_DIRECT */ +#if defined(CYASSL_AES_DIRECT) || defined(CYASSL_AES_COUNTER) + +/* AES-CTR and AES-DIRECT need to use this for key setup, no aesni yet */ +int AesSetKeyDirect(Aes* aes, const byte* userKey, word32 keylen, + const byte* iv, int dir) +{ + return AesSetKeyLocal(aes, userKey, keylen, iv, dir); +} + +#endif /* CYASSL_AES_DIRECT || CYASSL_AES_COUNTER */ + + #ifdef CYASSL_AES_COUNTER /* Increment AES counter */ diff --git a/cyassl/ctaocrypt/aes.h b/cyassl/ctaocrypt/aes.h index 9ab625dfc..e8dc53312 100644 --- a/cyassl/ctaocrypt/aes.h +++ b/cyassl/ctaocrypt/aes.h @@ -87,7 +87,8 @@ CYASSL_API void AesCbcDecrypt(Aes* aes, byte* out, const byte* in, word32 sz); CYASSL_API void AesCtrEncrypt(Aes* aes, byte* out, const byte* in, word32 sz); CYASSL_API void AesEncryptDirect(Aes* aes, byte* out, const byte* in); CYASSL_API void AesDecryptDirect(Aes* aes, byte* out, const byte* in); - +CYASSL_API int AesSetKeyDirect(Aes* aes, const byte* key, word32 len, + const byte* iv, int dir); #ifdef HAVE_AESGCM CYASSL_API void AesGcmSetKey(Aes* aes, const byte* key, word32 len, const byte* implicitIV); diff --git a/cyassl/internal.h b/cyassl/internal.h index 887c377d9..6e4604849 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1020,6 +1020,7 @@ typedef struct Ciphers { #ifdef BUILD_RABBIT Rabbit* rabbit; #endif + byte setup; /* have we set it up flag for detection */ } Ciphers; diff --git a/examples/client/client.c b/examples/client/client.c index 5c1a3290d..d2e9c7ce4 100644 --- a/examples/client/client.c +++ b/examples/client/client.c @@ -51,6 +51,7 @@ int timeout_count = CyaSSL_dtls_get_current_timeout(ssl) * 10; while (ret != SSL_SUCCESS && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) { + (void)timeout_count; if (error == SSL_ERROR_WANT_READ) printf("... client would read block\n"); else diff --git a/src/internal.c b/src/internal.c index e89f23f04..d3cd51e00 100644 --- a/src/internal.c +++ b/src/internal.c @@ -465,6 +465,8 @@ void InitCiphers(CYASSL* ssl) ssl->encrypt.rabbit = NULL; ssl->decrypt.rabbit = NULL; #endif + ssl->encrypt.setup = 0; + ssl->decrypt.setup = 0; } @@ -2672,8 +2674,13 @@ static INLINE word32 GetSEQIncrement(CYASSL* ssl, int verify) } -static INLINE void Encrypt(CYASSL* ssl, byte* out, const byte* input, word32 sz) +static INLINE int Encrypt(CYASSL* ssl, byte* out, const byte* input, word32 sz) { + if (ssl->encrypt.setup == 0) { + CYASSL_MSG("Encrypt ciphers not setup"); + return ENCRYPT_ERROR; + } + switch (ssl->specs.bulk_cipher_algorithm) { #ifdef BUILD_ARC4 case rc4: @@ -2745,13 +2752,21 @@ static INLINE void Encrypt(CYASSL* ssl, byte* out, const byte* input, word32 sz) default: CYASSL_MSG("CyaSSL Encrypt programming error"); + return ENCRYPT_ERROR; } + + return 0; } static INLINE int Decrypt(CYASSL* ssl, byte* plain, const byte* input, word32 sz) { + if (ssl->decrypt.setup == 0) { + CYASSL_MSG("Decrypt ciphers not setup"); + return DECRYPT_ERROR; + } + switch (ssl->specs.bulk_cipher_algorithm) { #ifdef BUILD_ARC4 case rc4: @@ -2815,6 +2830,7 @@ static INLINE int Decrypt(CYASSL* ssl, byte* plain, const byte* input, default: CYASSL_MSG("CyaSSL Decrypt programming error"); + return DECRYPT_ERROR; } return 0; } @@ -3498,6 +3514,7 @@ static int BuildMessage(CYASSL* ssl, byte* output, const byte* input, int inSz, word32 headerSz = RECORD_HEADER_SZ; word16 size; byte iv[AES_BLOCK_SIZE]; /* max size */ + int ret = 0; #ifdef CYASSL_DTLS if (ssl->options.dtls) { @@ -3541,7 +3558,6 @@ static int BuildMessage(CYASSL* ssl, byte* output, const byte* input, int inSz, if (type == handshake) { #ifdef CYASSL_DTLS if (ssl->options.dtls) { - int ret; if ((ret = DtlsPoolSave(ssl, output, headerSz+inSz)) != 0) return ret; } @@ -3557,7 +3573,8 @@ static int BuildMessage(CYASSL* ssl, byte* output, const byte* input, int inSz, for (i = 0; i <= pad; i++) output[idx++] = (byte)pad; /* pad byte gets pad value too */ - Encrypt(ssl, output + headerSz, output + headerSz, size); + if ( (ret = Encrypt(ssl, output + headerSz, output + headerSz, size)) != 0) + return ret; return sz; } diff --git a/src/io.c b/src/io.c index fde60ae1b..85a03c67a 100644 --- a/src/io.c +++ b/src/io.c @@ -167,7 +167,7 @@ int EmbedReceive(CYASSL *ssl, char *buf, int sz, void *ctx) CYASSL_MSG("Embed Receive error"); if (err == SOCKET_EWOULDBLOCK || err == SOCKET_EAGAIN) { - if (CyaSSL_get_using_nonblock(ssl)) { + if (!CyaSSL_dtls(ssl) || CyaSSL_get_using_nonblock(ssl)) { CYASSL_MSG(" Would block"); return IO_ERR_WANT_READ; } diff --git a/src/keys.c b/src/keys.c index 889739d8d..c40fb5f51 100644 --- a/src/keys.c +++ b/src/keys.c @@ -937,6 +937,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, Arc4SetKey(enc->arc4, keys->server_write_key, sz); Arc4SetKey(dec->arc4, keys->client_write_key, sz); } + enc->setup = 1; + dec->setup = 1; } #endif @@ -960,6 +962,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, Hc128_SetKey(dec->hc128, keys->client_write_key, keys->client_write_IV); } + enc->setup = 1; + dec->setup = 1; } #endif @@ -983,6 +987,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, RabbitSetKey(dec->rabbit, keys->client_write_key, keys->client_write_IV); } + enc->setup = 1; + dec->setup = 1; } #endif @@ -1006,6 +1012,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, Des3_SetKey(dec->des3, keys->client_write_key, keys->client_write_IV, DES_DECRYPTION); } + enc->setup = 1; + dec->setup = 1; } #endif @@ -1033,6 +1041,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, specs->key_size, keys->client_write_IV, AES_DECRYPTION); } + enc->setup = 1; + dec->setup = 1; } #endif @@ -1062,6 +1072,8 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, AesGcmSetKey(dec->aes, keys->client_write_key, specs->key_size, keys->client_write_IV); } + enc->setup = 1; + dec->setup = 1; } #endif diff --git a/src/sniffer.c b/src/sniffer.c index 4ae300c84..b385c28aa 100644 --- a/src/sniffer.c +++ b/src/sniffer.c @@ -845,7 +845,7 @@ static SnifferSession* GetSnifferSession(IpInfo* ipInfo, TcpInfo* tcpInfo) SnifferSession* session; word32 row = SessionHash(ipInfo, tcpInfo); - assert(row >= 0 && row <= HASH_SIZE); + assert(row <= HASH_SIZE); LockMutex(&SessionMutex); @@ -1585,7 +1585,7 @@ static void RemoveSession(SnifferSession* session, IpInfo* ipInfo, else haveLock = 1; - assert(row >= 0 && row <= HASH_SIZE); + assert(row <= HASH_SIZE); Trace(REMOVE_SESSION_STR); if (!haveLock) @@ -1663,12 +1663,16 @@ static SnifferSession* CreateSession(IpInfo* ipInfo, TcpInfo* tcpInfo, } session->sslServer = SSL_new(session->context->ctx); + if (session->sslServer == NULL) { + SetError(BAD_NEW_SSL_STR, error, session, FATAL_ERROR_STATE); + free(session); + return 0; + } session->sslClient = SSL_new(session->context->ctx); if (session->sslClient == NULL) { - if (session->sslServer) { - SSL_free(session->sslClient); - session->sslClient = 0; - } + SSL_free(session->sslServer); + session->sslServer = 0; + SetError(BAD_NEW_SSL_STR, error, session, FATAL_ERROR_STATE); free(session); return 0;