From 42a0f3500fbf42bb021bd3a831caf62b15f67f2b Mon Sep 17 00:00:00 2001 From: John Safranek Date: Sun, 5 May 2013 20:55:38 -0700 Subject: [PATCH] Update AES-GCM and AES-CCM to use AES-NI 1. Added the assembly functions to do AES-ECB. 2. Updated AesEncrypt and AesDecrypt to use the assembly functions if available. 3. Modified the AES-GCM and AES-CCM key setup functions to use the the AES-NI key setup if availble. 4. Added tests for the AES-ECB encrypt and decrypt. 5. Only include stdio.h for AES when DEBUG_AESNI is enabled 6. If using local key setup, skip using AES-NI for basic Encrypt and Decrypt. --- ctaocrypt/src/aes.c | 84 +++++++++-- ctaocrypt/src/aes_asm.s | 320 ++++++++++++++++++++++++++++++++++++++++ ctaocrypt/test/test.c | 36 +++++ 3 files changed, 430 insertions(+), 10 deletions(-) diff --git a/ctaocrypt/src/aes.c b/ctaocrypt/src/aes.c index 26cb53c53..cc0e815e8 100644 --- a/ctaocrypt/src/aes.c +++ b/ctaocrypt/src/aes.c @@ -35,6 +35,9 @@ #else #include #endif +#ifdef DEBUG_AESNI + #include +#endif #ifdef _MSC_VER @@ -1148,6 +1151,15 @@ void AES_CBC_decrypt(const unsigned char* in, unsigned char* out, const unsigned char* KS, int nr) asm ("AES_CBC_decrypt"); +void AES_ECB_encrypt(const unsigned char* in, unsigned char* out, + unsigned long length, const unsigned char* KS, int nr) + asm ("AES_ECB_encrypt"); + + +void AES_ECB_decrypt(const unsigned char* in, unsigned char* out, + unsigned long length, const unsigned char* KS, int nr) + asm ("AES_ECB_decrypt"); + void AES_128_Key_Expansion(const unsigned char* userkey, unsigned char* key_schedule) asm ("AES_128_Key_Expansion"); @@ -1409,10 +1421,45 @@ static void AesEncrypt(Aes* aes, const byte* inBlock, byte* outBlock) return; /* stop instead of segfaulting, set up your keys! */ } #ifdef CYASSL_AESNI - if (aes->use_aesni) { - CYASSL_MSG("AesEncrypt encountered aesni keysetup, don't use direct"); - return; /* just stop now */ - } + if (haveAESNI && aes->use_aesni) { + #ifdef DEBUG_AESNI + printf("about to aes encrypt\n"); + printf("in = %p\n", inBlock); + printf("out = %p\n", outBlock); + printf("aes->key = %p\n", aes->key); + printf("aes->rounds = %d\n", aes->rounds); + printf("sz = %d\n", AES_BLOCK_SIZE); + #endif + + /* check alignment, decrypt doesn't need alignment */ + if ((word)inBlock % 16) { + #ifndef NO_CYASSL_ALLOC_ALIGN + byte* tmp = (byte*)XMALLOC(AES_BLOCK_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (tmp == NULL) return; + + XMEMCPY(tmp, inBlock, AES_BLOCK_SIZE); + AES_ECB_encrypt(tmp, tmp, AES_BLOCK_SIZE, (byte*)aes->key, + aes->rounds); + XMEMCPY(outBlock, tmp, AES_BLOCK_SIZE); + XFREE(tmp, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return; + #else + CYASSL_MSG("AES-ECB encrypt with bad alignment"); + return; + #endif + } + + AES_ECB_encrypt(inBlock, outBlock, AES_BLOCK_SIZE, (byte*)aes->key, + aes->rounds); + + return; + } + else { + #ifdef DEBUG_AESNI + printf("Skipping AES-NI\n"); + #endif + } #endif /* @@ -1554,10 +1601,27 @@ static void AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock) return; /* stop instead of segfaulting, set up your keys! */ } #ifdef CYASSL_AESNI - if (aes->use_aesni) { - CYASSL_MSG("AesEncrypt encountered aesni keysetup, don't use direct"); - return; /* just stop now */ - } + if (haveAESNI && aes->use_aesni) { + #ifdef DEBUG_AESNI + printf("about to aes decrypt\n"); + printf("in = %p\n", inBlock); + printf("out = %p\n", outBlock); + printf("aes->key = %p\n", aes->key); + printf("aes->rounds = %d\n", aes->rounds); + printf("sz = %d\n", AES_BLOCK_SIZE); + #endif + + /* if input and output same will overwrite input iv */ + XMEMCPY(aes->tmp, inBlock, AES_BLOCK_SIZE); + AES_ECB_decrypt(inBlock, outBlock, AES_BLOCK_SIZE, (byte*)aes->key, + aes->rounds); + return; + } + else { + #ifdef DEBUG_AESNI + printf("Skipping AES-NI\n"); + #endif + } #endif /* @@ -1969,7 +2033,7 @@ void AesGcmSetKey(Aes* aes, const byte* key, word32 len) return; XMEMSET(iv, 0, AES_BLOCK_SIZE); - AesSetKeyLocal(aes, key, len, iv, AES_ENCRYPTION); + AesSetKey(aes, key, len, iv, AES_ENCRYPTION); AesEncrypt(aes, iv, aes->H); #ifdef GCM_TABLE @@ -2584,7 +2648,7 @@ void AesCcmSetKey(Aes* aes, const byte* key, word32 keySz) return; XMEMSET(nonce, 0, sizeof(nonce)); - AesSetKeyLocal(aes, key, keySz, nonce, AES_ENCRYPTION); + AesSetKey(aes, key, keySz, nonce, AES_ENCRYPTION); } diff --git a/ctaocrypt/src/aes_asm.s b/ctaocrypt/src/aes_asm.s index 829b2fe14..a1df56b70 100755 --- a/ctaocrypt/src/aes_asm.s +++ b/ctaocrypt/src/aes_asm.s @@ -84,6 +84,7 @@ ret + /* AES_CBC_decrypt (const unsigned char *in, unsigned char *out, @@ -262,6 +263,325 @@ DEND_4: ret +/* +AES_ECB_encrypt (const unsigned char *in, + unsigned char *out, + unsigned long length, + const unsigned char *KS, + int nr) +*/ +.globl AES_ECB_encrypt +AES_ECB_encrypt: +# parameter 1: %rdi +# parameter 2: %rsi +# parameter 3: %rdx +# parameter 4: %rcx +# parameter 5: %r8d + movq %rdx, %r10 + shrq $4, %rdx + shlq $60, %r10 + je EECB_NO_PARTS_4 + addq $1, %rdx +EECB_NO_PARTS_4: + movq %rdx, %r10 + shlq $62, %r10 + shrq $62, %r10 + shrq $2, %rdx + je EECB_REMAINDER_4 + subq $64, %rsi +EECB_LOOP_4: + movdqu (%rdi), %xmm1 + movdqu 16(%rdi), %xmm2 + movdqu 32(%rdi), %xmm3 + movdqu 48(%rdi), %xmm4 + movdqa (%rcx), %xmm9 + movdqa 16(%rcx), %xmm10 + movdqa 32(%rcx), %xmm11 + movdqa 48(%rcx), %xmm12 + pxor %xmm9, %xmm1 + pxor %xmm9, %xmm2 + pxor %xmm9, %xmm3 + pxor %xmm9, %xmm4 + aesenc %xmm10, %xmm1 + aesenc %xmm10, %xmm2 + aesenc %xmm10, %xmm3 + aesenc %xmm10, %xmm4 + aesenc %xmm11, %xmm1 + aesenc %xmm11, %xmm2 + aesenc %xmm11, %xmm3 + aesenc %xmm11, %xmm4 + aesenc %xmm12, %xmm1 + aesenc %xmm12, %xmm2 + aesenc %xmm12, %xmm3 + aesenc %xmm12, %xmm4 + movdqa 64(%rcx), %xmm9 + movdqa 80(%rcx), %xmm10 + movdqa 96(%rcx), %xmm11 + movdqa 112(%rcx), %xmm12 + aesenc %xmm9, %xmm1 + aesenc %xmm9, %xmm2 + aesenc %xmm9, %xmm3 + aesenc %xmm9, %xmm4 + aesenc %xmm10, %xmm1 + aesenc %xmm10, %xmm2 + aesenc %xmm10, %xmm3 + aesenc %xmm10, %xmm4 + aesenc %xmm11, %xmm1 + aesenc %xmm11, %xmm2 + aesenc %xmm11, %xmm3 + aesenc %xmm11, %xmm4 + aesenc %xmm12, %xmm1 + aesenc %xmm12, %xmm2 + aesenc %xmm12, %xmm3 + aesenc %xmm12, %xmm4 + movdqa 128(%rcx), %xmm9 + movdqa 144(%rcx), %xmm10 + movdqa 160(%rcx), %xmm11 + cmpl $12, %r8d + aesenc %xmm9, %xmm1 + aesenc %xmm9, %xmm2 + aesenc %xmm9, %xmm3 + aesenc %xmm9, %xmm4 + aesenc %xmm10, %xmm1 + aesenc %xmm10, %xmm2 + aesenc %xmm10, %xmm3 + aesenc %xmm10, %xmm4 + jb EECB_LAST_4 + movdqa 160(%rcx), %xmm9 + movdqa 176(%rcx), %xmm10 + movdqa 192(%rcx), %xmm11 + cmpl $14, %r8d + aesenc %xmm9, %xmm1 + aesenc %xmm9, %xmm2 + aesenc %xmm9, %xmm3 + aesenc %xmm9, %xmm4 + aesenc %xmm10, %xmm1 + aesenc %xmm10, %xmm2 + aesenc %xmm10, %xmm3 + aesenc %xmm10, %xmm4 + jb EECB_LAST_4 + movdqa 192(%rcx), %xmm9 + movdqa 208(%rcx), %xmm10 + movdqa 224(%rcx), %xmm11 + aesenc %xmm9, %xmm1 + aesenc %xmm9, %xmm2 + aesenc %xmm9, %xmm3 + aesenc %xmm9, %xmm4 + aesenc %xmm10, %xmm1 + aesenc %xmm10, %xmm2 + aesenc %xmm10, %xmm3 + aesenc %xmm10, %xmm4 +EECB_LAST_4: + addq $64, %rdi + addq $64, %rsi + decq %rdx + aesenclast %xmm11, %xmm1 + aesenclast %xmm11, %xmm2 + aesenclast %xmm11, %xmm3 + aesenclast %xmm11, %xmm4 + movdqu %xmm1, (%rsi) + movdqu %xmm2, 16(%rsi) + movdqu %xmm3, 32(%rsi) + movdqu %xmm4, 48(%rsi) + jne EECB_LOOP_4 + addq $64, %rsi +EECB_REMAINDER_4: + cmpq $0, %r10 + je EECB_END_4 +EECB_LOOP_4_2: + movdqu (%rdi), %xmm1 + addq $16, %rdi + pxor (%rcx), %xmm1 + movdqu 160(%rcx), %xmm2 + aesenc 16(%rcx), %xmm1 + aesenc 32(%rcx), %xmm1 + aesenc 48(%rcx), %xmm1 + aesenc 64(%rcx), %xmm1 + aesenc 80(%rcx), %xmm1 + aesenc 96(%rcx), %xmm1 + aesenc 112(%rcx), %xmm1 + aesenc 128(%rcx), %xmm1 + aesenc 144(%rcx), %xmm1 + cmpl $12, %r8d + jb EECB_LAST_4_2 + movdqu 192(%rcx), %xmm2 + aesenc 160(%rcx), %xmm1 + aesenc 176(%rcx), %xmm1 + cmpl $14, %r8d + jb EECB_LAST_4_2 + movdqu 224(%rcx), %xmm2 + aesenc 192(%rcx), %xmm1 + aesenc 208(%rcx), %xmm1 +EECB_LAST_4_2: + aesenclast %xmm2, %xmm1 + movdqu %xmm1, (%rsi) + addq $16, %rsi + decq %r10 + jne EECB_LOOP_4_2 +EECB_END_4: + ret + + +/* +AES_ECB_decrypt (const unsigned char *in, + unsigned char *out, + unsigned long length, + const unsigned char *KS, + int nr) +*/ +.globl AES_ECB_decrypt +AES_ECB_decrypt: +# parameter 1: %rdi +# parameter 2: %rsi +# parameter 3: %rdx +# parameter 4: %rcx +# parameter 5: %r8d + + movq %rdx, %r10 + shrq $4, %rdx + shlq $60, %r10 + je DECB_NO_PARTS_4 + addq $1, %rdx +DECB_NO_PARTS_4: + movq %rdx, %r10 + shlq $62, %r10 + shrq $62, %r10 + shrq $2, %rdx + je DECB_REMAINDER_4 + subq $64, %rsi +DECB_LOOP_4: + movdqu (%rdi), %xmm1 + movdqu 16(%rdi), %xmm2 + movdqu 32(%rdi), %xmm3 + movdqu 48(%rdi), %xmm4 + movdqa (%rcx), %xmm9 + movdqa 16(%rcx), %xmm10 + movdqa 32(%rcx), %xmm11 + movdqa 48(%rcx), %xmm12 + pxor %xmm9, %xmm1 + pxor %xmm9, %xmm2 + pxor %xmm9, %xmm3 + pxor %xmm9, %xmm4 + aesdec %xmm10, %xmm1 + aesdec %xmm10, %xmm2 + aesdec %xmm10, %xmm3 + aesdec %xmm10, %xmm4 + aesdec %xmm11, %xmm1 + aesdec %xmm11, %xmm2 + aesdec %xmm11, %xmm3 + aesdec %xmm11, %xmm4 + aesdec %xmm12, %xmm1 + aesdec %xmm12, %xmm2 + aesdec %xmm12, %xmm3 + aesdec %xmm12, %xmm4 + movdqa 64(%rcx), %xmm9 + movdqa 80(%rcx), %xmm10 + movdqa 96(%rcx), %xmm11 + movdqa 112(%rcx), %xmm12 + aesdec %xmm9, %xmm1 + aesdec %xmm9, %xmm2 + aesdec %xmm9, %xmm3 + aesdec %xmm9, %xmm4 + aesdec %xmm10, %xmm1 + aesdec %xmm10, %xmm2 + aesdec %xmm10, %xmm3 + aesdec %xmm10, %xmm4 + aesdec %xmm11, %xmm1 + aesdec %xmm11, %xmm2 + aesdec %xmm11, %xmm3 + aesdec %xmm11, %xmm4 + aesdec %xmm12, %xmm1 + aesdec %xmm12, %xmm2 + aesdec %xmm12, %xmm3 + aesdec %xmm12, %xmm4 + movdqa 128(%rcx), %xmm9 + movdqa 144(%rcx), %xmm10 + movdqa 160(%rcx), %xmm11 + cmpl $12, %r8d + aesdec %xmm9, %xmm1 + aesdec %xmm9, %xmm2 + aesdec %xmm9, %xmm3 + aesdec %xmm9, %xmm4 + aesdec %xmm10, %xmm1 + aesdec %xmm10, %xmm2 + aesdec %xmm10, %xmm3 + aesdec %xmm10, %xmm4 + jb DECB_LAST_4 + movdqa 160(%rcx), %xmm9 + movdqa 176(%rcx), %xmm10 + movdqa 192(%rcx), %xmm11 + cmpl $14, %r8d + aesdec %xmm9, %xmm1 + aesdec %xmm9, %xmm2 + aesdec %xmm9, %xmm3 + aesdec %xmm9, %xmm4 + aesdec %xmm10, %xmm1 + aesdec %xmm10, %xmm2 + aesdec %xmm10, %xmm3 + aesdec %xmm10, %xmm4 + jb DECB_LAST_4 + movdqa 192(%rcx), %xmm9 + movdqa 208(%rcx), %xmm10 + movdqa 224(%rcx), %xmm11 + aesdec %xmm9, %xmm1 + aesdec %xmm9, %xmm2 + aesdec %xmm9, %xmm3 + aesdec %xmm9, %xmm4 + aesdec %xmm10, %xmm1 + aesdec %xmm10, %xmm2 + aesdec %xmm10, %xmm3 + aesdec %xmm10, %xmm4 +DECB_LAST_4: + addq $64, %rdi + addq $64, %rsi + decq %rdx + aesdeclast %xmm11, %xmm1 + aesdeclast %xmm11, %xmm2 + aesdeclast %xmm11, %xmm3 + aesdeclast %xmm11, %xmm4 + movdqu %xmm1, (%rsi) + movdqu %xmm2, 16(%rsi) + movdqu %xmm3, 32(%rsi) + movdqu %xmm4, 48(%rsi) + jne DECB_LOOP_4 + addq $64, %rsi +DECB_REMAINDER_4: + cmpq $0, %r10 + je DECB_END_4 +DECB_LOOP_4_2: + movdqu (%rdi), %xmm1 + addq $16, %rdi + pxor (%rcx), %xmm1 + movdqu 160(%rcx), %xmm2 + cmpl $12, %r8d + aesdec 16(%rcx), %xmm1 + aesdec 32(%rcx), %xmm1 + aesdec 48(%rcx), %xmm1 + aesdec 64(%rcx), %xmm1 + aesdec 80(%rcx), %xmm1 + aesdec 96(%rcx), %xmm1 + aesdec 112(%rcx), %xmm1 + aesdec 128(%rcx), %xmm1 + aesdec 144(%rcx), %xmm1 + jb DECB_LAST_4_2 + cmpl $14, %r8d + movdqu 192(%rcx), %xmm2 + aesdec 160(%rcx), %xmm1 + aesdec 176(%rcx), %xmm1 + jb DECB_LAST_4_2 + movdqu 224(%rcx), %xmm2 + aesdec 192(%rcx), %xmm1 + aesdec 208(%rcx), %xmm1 +DECB_LAST_4_2: + aesdeclast %xmm2, %xmm1 + movdqu %xmm1, (%rsi) + addq $16, %rsi + decq %r10 + jne DECB_LOOP_4_2 +DECB_END_4: + ret + + /* diff --git a/ctaocrypt/test/test.c b/ctaocrypt/test/test.c index 4d832524c..5137b0d91 100644 --- a/ctaocrypt/test/test.c +++ b/ctaocrypt/test/test.c @@ -1800,6 +1800,42 @@ int aes_test(void) } #endif /* CYASSL_AES_COUNTER */ +#if defined(CYASSL_AESNI) && defined(CYASSL_AES_DIRECT) + { + const byte niPlain[] = + { + 0x6b,0xc1,0xbe,0xe2,0x2e,0x40,0x9f,0x96, + 0xe9,0x3d,0x7e,0x11,0x73,0x93,0x17,0x2a + }; + + const byte niCipher[] = + { + 0xf3,0xee,0xd1,0xbd,0xb5,0xd2,0xa0,0x3c, + 0x06,0x4b,0x5a,0x7e,0x3d,0xb1,0x81,0xf8 + }; + + const byte niKey[] = + { + 0x60,0x3d,0xeb,0x10,0x15,0xca,0x71,0xbe, + 0x2b,0x73,0xae,0xf0,0x85,0x7d,0x77,0x81, + 0x1f,0x35,0x2c,0x07,0x3b,0x61,0x08,0xd7, + 0x2d,0x98,0x10,0xa3,0x09,0x14,0xdf,0xf4 + }; + + XMEMSET(cipher, 0, AES_BLOCK_SIZE); + AesSetKey(&enc, niKey, sizeof(niKey), cipher, AES_ENCRYPTION); + AesEncryptDirect(&enc, cipher, niPlain); + if (XMEMCMP(cipher, niCipher, AES_BLOCK_SIZE) != 0) + return -20006; + + XMEMSET(plain, 0, AES_BLOCK_SIZE); + AesSetKey(&dec, niKey, sizeof(niKey), plain, AES_DECRYPTION); + AesDecryptDirect(&dec, plain, niCipher); + if (XMEMCMP(plain, niPlain, AES_BLOCK_SIZE) != 0) + return -20007; + } +#endif /* CYASSL_AESNI && CYASSL_AES_DIRECT */ + return 0; }