From 957fc7460c46836cb1c4d8079fbbb74d1ba314f9 Mon Sep 17 00:00:00 2001 From: Daniel Pouzzner Date: Sat, 27 Jan 2024 23:16:02 -0600 Subject: [PATCH] linuxkm/lkcapi_glue.c: refactor AES-CBC, AES-CFB, and AES-GCM glue around struct km_AesCtx with separate aes_encrypt and aes_decrypt Aes pointers, and no cached key, to avoid AesSetKey operations at encrypt/decrypt time. --- linuxkm/lkcapi_glue.c | 172 +++++++++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 61 deletions(-) diff --git a/linuxkm/lkcapi_glue.c b/linuxkm/lkcapi_glue.c index 2ec81eedb..6cf84d98c 100644 --- a/linuxkm/lkcapi_glue.c +++ b/linuxkm/lkcapi_glue.c @@ -96,35 +96,57 @@ static int linuxkm_test_aesxts(void); #include struct km_AesCtx { - Aes *aes; /* must be pointer to control alignment, needed for AESNI. */ - u8 key[AES_MAX_KEY_SIZE / 8]; - unsigned int keylen; + Aes *aes_encrypt; /* must be pointer to control alignment, needed for AESNI. */ + Aes *aes_decrypt; /* same. */ }; -static inline void km_ForceZero(struct km_AesCtx * ctx) -{ - memzero_explicit(ctx->key, sizeof(ctx->key)); - ctx->keylen = 0; -} - #if defined(LINUXKM_LKCAPI_REGISTER_ALL) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCFB) || \ defined(LINUXKM_LKCAPI_REGISTER_AESGCM) -static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name) +static void km_AesExitCommon(struct km_AesCtx * ctx); + +static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name, int need_decryption) { int err; - ctx->aes = (Aes *)malloc(sizeof(*ctx->aes)); + ctx->aes_encrypt = (Aes *)malloc(sizeof(*ctx->aes_encrypt)); - if (! ctx->aes) + if (! ctx->aes_encrypt) { + pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E); return MEMORY_E; + } - err = wc_AesInit(ctx->aes, NULL, INVALID_DEVID); + err = wc_AesInit(ctx->aes_encrypt, NULL, INVALID_DEVID); if (unlikely(err)) { pr_err("error: km_AesInitCommon %s failed: %d\n", name, err); + free(ctx->aes_encrypt); + ctx->aes_encrypt = NULL; + return err; + } + + if (! need_decryption) { + ctx->aes_decrypt = NULL; + return 0; + } + + ctx->aes_decrypt = (Aes *)malloc(sizeof(*ctx->aes_decrypt)); + + if (! ctx->aes_encrypt) { + pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E); + km_AesExitCommon(ctx); + return MEMORY_E; + } + + err = wc_AesInit(ctx->aes_decrypt, NULL, INVALID_DEVID); + + if (unlikely(err)) { + pr_err("error: km_AesInitCommon %s failed: %d\n", name, err); + free(ctx->aes_decrypt); + ctx->aes_decrypt = NULL; + km_AesExitCommon(ctx); return err; } @@ -133,10 +155,16 @@ static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name) static void km_AesExitCommon(struct km_AesCtx * ctx) { - wc_AesFree(ctx->aes); - free(ctx->aes); - ctx->aes = NULL; - km_ForceZero(ctx); + if (ctx->aes_encrypt) { + wc_AesFree(ctx->aes_encrypt); + free(ctx->aes_encrypt); + ctx->aes_encrypt = NULL; + } + if (ctx->aes_decrypt) { + wc_AesFree(ctx->aes_decrypt); + free(ctx->aes_decrypt); + ctx->aes_decrypt = NULL; + } } static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key, @@ -144,15 +172,21 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key, { int err; - err = wc_AesSetKey(ctx->aes, in_key, key_len, NULL, 0); + err = wc_AesSetKey(ctx->aes_encrypt, in_key, key_len, NULL, AES_ENCRYPTION); if (unlikely(err)) { pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err); return err; } - XMEMCPY(ctx->key, in_key, key_len); - ctx->keylen = key_len; + if (ctx->aes_decrypt) { + err = wc_AesSetKey(ctx->aes_decrypt, in_key, key_len, NULL, AES_DECRYPTION); + + if (unlikely(err)) { + pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err); + return err; + } + } return 0; } @@ -161,25 +195,12 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key, defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCFB) -static int km_AesInit(struct crypto_skcipher *tfm) -{ - struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); - return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER); -} - static void km_AesExit(struct crypto_skcipher *tfm) { struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); km_AesExitCommon(ctx); } -static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key, - unsigned int key_len) -{ - struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); - return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER); -} - #endif /* LINUXKM_LKCAPI_REGISTER_ALL || * LINUXKM_LKCAPI_REGISTER_AESCBC || * LINUXKM_LKCAPI_REGISTER_AESCFB @@ -192,6 +213,19 @@ static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key, #if defined(HAVE_AES_CBC) && \ (defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC)) +static int km_AesCbcInit(struct crypto_skcipher *tfm) +{ + struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); + return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER, 1); +} + +static int km_AesCbcSetKey(struct crypto_skcipher *tfm, const u8 *in_key, + unsigned int key_len) +{ + struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); + return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER); +} + static int km_AesCbcEncrypt(struct skcipher_request *req) { struct crypto_skcipher * tfm = NULL; @@ -206,15 +240,14 @@ static int km_AesCbcEncrypt(struct skcipher_request *req) err = skcipher_walk_virt(&walk, req, false); while ((nbytes = walk.nbytes)) { - err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, - AES_ENCRYPTION); + err = wc_AesSetIV(ctx->aes_encrypt, walk.iv); if (unlikely(err)) { - pr_err("wc_AesSetKey failed: %d\n", err); + pr_err("wc_AesSetIV failed: %d\n", err); return err; } - err = wc_AesCbcEncrypt(ctx->aes, walk.dst.virt.addr, + err = wc_AesCbcEncrypt(ctx->aes_encrypt, walk.dst.virt.addr, walk.src.virt.addr, nbytes); if (unlikely(err)) { @@ -242,15 +275,14 @@ static int km_AesCbcDecrypt(struct skcipher_request *req) err = skcipher_walk_virt(&walk, req, false); while ((nbytes = walk.nbytes)) { - err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, - AES_DECRYPTION); + err = wc_AesSetIV(ctx->aes_decrypt, walk.iv); if (unlikely(err)) { pr_err("wc_AesSetKey failed"); return err; } - err = wc_AesCbcDecrypt(ctx->aes, walk.dst.virt.addr, + err = wc_AesCbcDecrypt(ctx->aes_decrypt, walk.dst.virt.addr, walk.src.virt.addr, nbytes); if (unlikely(err)) { @@ -271,12 +303,12 @@ static struct skcipher_alg cbcAesAlg = { .base.cra_blocksize = AES_BLOCK_SIZE, .base.cra_ctxsize = sizeof(struct km_AesCtx), .base.cra_module = THIS_MODULE, - .init = km_AesInit, + .init = km_AesCbcInit, .exit = km_AesExit, .min_keysize = AES_128_KEY_SIZE, .max_keysize = AES_256_KEY_SIZE, .ivsize = AES_BLOCK_SIZE, - .setkey = km_AesSetKey, + .setkey = km_AesCbcSetKey, .encrypt = km_AesCbcEncrypt, .decrypt = km_AesCbcDecrypt, }; @@ -289,6 +321,19 @@ static int cbcAesAlg_loaded = 0; #if defined(WOLFSSL_AES_CFB) && \ (defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB)) +static int km_AesCfbInit(struct crypto_skcipher *tfm) +{ + struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); + return km_AesInitCommon(ctx, WOLFKM_AESCFB_DRIVER, 0); +} + +static int km_AesCfbSetKey(struct crypto_skcipher *tfm, const u8 *in_key, + unsigned int key_len) +{ + struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); + return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCFB_DRIVER); +} + static int km_AesCfbEncrypt(struct skcipher_request *req) { struct crypto_skcipher * tfm = NULL; @@ -303,15 +348,14 @@ static int km_AesCfbEncrypt(struct skcipher_request *req) err = skcipher_walk_virt(&walk, req, false); while ((nbytes = walk.nbytes)) { - err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, - AES_ENCRYPTION); + err = wc_AesSetIV(ctx->aes_encrypt, walk.iv); if (unlikely(err)) { pr_err("wc_AesSetKey failed: %d\n", err); return err; } - err = wc_AesCfbEncrypt(ctx->aes, walk.dst.virt.addr, + err = wc_AesCfbEncrypt(ctx->aes_encrypt, walk.dst.virt.addr, walk.src.virt.addr, nbytes); if (unlikely(err)) { @@ -339,15 +383,14 @@ static int km_AesCfbDecrypt(struct skcipher_request *req) err = skcipher_walk_virt(&walk, req, false); while ((nbytes = walk.nbytes)) { - err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, - AES_ENCRYPTION); + err = wc_AesSetIV(ctx->aes_encrypt, walk.iv); if (unlikely(err)) { pr_err("wc_AesSetKey failed"); return err; } - err = wc_AesCfbDecrypt(ctx->aes, walk.dst.virt.addr, + err = wc_AesCfbDecrypt(ctx->aes_encrypt, walk.dst.virt.addr, walk.src.virt.addr, nbytes); if (unlikely(err)) { @@ -368,12 +411,12 @@ static struct skcipher_alg cfbAesAlg = { .base.cra_blocksize = AES_BLOCK_SIZE, .base.cra_ctxsize = sizeof(struct km_AesCtx), .base.cra_module = THIS_MODULE, - .init = km_AesInit, + .init = km_AesCfbInit, .exit = km_AesExit, .min_keysize = AES_128_KEY_SIZE, .max_keysize = AES_256_KEY_SIZE, .ivsize = AES_BLOCK_SIZE, - .setkey = km_AesSetKey, + .setkey = km_AesCfbSetKey, .encrypt = km_AesCfbEncrypt, .decrypt = km_AesCfbDecrypt, }; @@ -390,8 +433,7 @@ static int cfbAesAlg_loaded = 0; static int km_AesGcmInit(struct crypto_aead * tfm) { struct km_AesCtx * ctx = crypto_aead_ctx(tfm); - km_ForceZero(ctx); - return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER); + return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER, 0); } static void km_AesGcmExit(struct crypto_aead * tfm) @@ -403,8 +445,16 @@ static void km_AesGcmExit(struct crypto_aead * tfm) static int km_AesGcmSetKey(struct crypto_aead *tfm, const u8 *in_key, unsigned int key_len) { + int err; struct km_AesCtx * ctx = crypto_aead_ctx(tfm); - return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESGCM_DRIVER); + + err = wc_AesGcmSetKey(ctx->aes_encrypt, in_key, key_len); + + if (err) { + pr_err("error: km_AesGcmSetKey %s failed: %d\n", WOLFKM_AESGCM_DRIVER, err); + } + + return err; } static int km_AesGcmSetAuthsize(struct crypto_aead *tfm, unsigned int authsize) @@ -454,7 +504,7 @@ static int km_AesGcmEncrypt(struct aead_request *req) return -1; } - err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv, + err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv, AES_BLOCK_SIZE); if (unlikely(err)) { pr_err("error: wc_AesGcmInit failed with return code %d.\n", err); @@ -467,7 +517,7 @@ static int km_AesGcmEncrypt(struct aead_request *req) return err; } - err = wc_AesGcmEncryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft); + err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft); assocLeft -= assocLeft; scatterwalk_unmap(assoc); assoc = NULL; @@ -483,7 +533,7 @@ static int km_AesGcmEncrypt(struct aead_request *req) if (likely(cryptLeft && nbytes)) { n = cryptLeft < nbytes ? cryptLeft : nbytes; - err = wc_AesGcmEncryptUpdate(ctx->aes, walk.dst.virt.addr, + err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr, walk.src.virt.addr, cryptLeft, NULL, 0); nbytes -= n; cryptLeft -= n; @@ -497,7 +547,7 @@ static int km_AesGcmEncrypt(struct aead_request *req) err = skcipher_walk_done(&walk, nbytes); } - err = wc_AesGcmEncryptFinal(ctx->aes, authTag, tfm->authsize); + err = wc_AesGcmEncryptFinal(ctx->aes_encrypt, authTag, tfm->authsize); if (unlikely(err)) { pr_err("error: wc_AesGcmEncryptFinal failed with return code %d\n", err); return err; @@ -542,7 +592,7 @@ static int km_AesGcmDecrypt(struct aead_request *req) return -1; } - err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv, + err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv, AES_BLOCK_SIZE); if (unlikely(err)) { pr_err("error: wc_AesGcmInit failed with return code %d.\n", err); @@ -555,7 +605,7 @@ static int km_AesGcmDecrypt(struct aead_request *req) return err; } - err = wc_AesGcmDecryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft); + err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft); assocLeft -= assocLeft; scatterwalk_unmap(assoc); assoc = NULL; @@ -571,7 +621,7 @@ static int km_AesGcmDecrypt(struct aead_request *req) if (likely(cryptLeft && nbytes)) { n = cryptLeft < nbytes ? cryptLeft : nbytes; - err = wc_AesGcmDecryptUpdate(ctx->aes, walk.dst.virt.addr, + err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr, walk.src.virt.addr, cryptLeft, NULL, 0); nbytes -= n; cryptLeft -= n; @@ -585,7 +635,7 @@ static int km_AesGcmDecrypt(struct aead_request *req) err = skcipher_walk_done(&walk, nbytes); } - err = wc_AesGcmDecryptFinal(ctx->aes, origAuthTag, tfm->authsize); + err = wc_AesGcmDecryptFinal(ctx->aes_encrypt, origAuthTag, tfm->authsize); if (unlikely(err)) { pr_err("error: wc_AesGcmDecryptFinal failed with return code %d\n", err);