linuxkm/lkcapi_glue.c: refactor AES-CBC, AES-CFB, and AES-GCM glue around struct km_AesCtx with separate aes_encrypt and aes_decrypt Aes pointers, and no cached key, to avoid AesSetKey operations at encrypt/decrypt time.

This commit is contained in:
Daniel Pouzzner
2024-01-27 23:16:02 -06:00
parent 8ae031a5ed
commit 957fc7460c

View File

@ -96,35 +96,57 @@ static int linuxkm_test_aesxts(void);
#include <wolfssl/wolfcrypt/aes.h> #include <wolfssl/wolfcrypt/aes.h>
struct km_AesCtx { struct km_AesCtx {
Aes *aes; /* must be pointer to control alignment, needed for AESNI. */ Aes *aes_encrypt; /* must be pointer to control alignment, needed for AESNI. */
u8 key[AES_MAX_KEY_SIZE / 8]; Aes *aes_decrypt; /* same. */
unsigned int keylen;
}; };
static inline void km_ForceZero(struct km_AesCtx * ctx)
{
memzero_explicit(ctx->key, sizeof(ctx->key));
ctx->keylen = 0;
}
#if defined(LINUXKM_LKCAPI_REGISTER_ALL) || \ #if defined(LINUXKM_LKCAPI_REGISTER_ALL) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCFB) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCFB) || \
defined(LINUXKM_LKCAPI_REGISTER_AESGCM) defined(LINUXKM_LKCAPI_REGISTER_AESGCM)
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name) static void km_AesExitCommon(struct km_AesCtx * ctx);
static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name, int need_decryption)
{ {
int err; int err;
ctx->aes = (Aes *)malloc(sizeof(*ctx->aes)); ctx->aes_encrypt = (Aes *)malloc(sizeof(*ctx->aes_encrypt));
if (! ctx->aes) if (! ctx->aes_encrypt) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
return MEMORY_E; return MEMORY_E;
}
err = wc_AesInit(ctx->aes, NULL, INVALID_DEVID); err = wc_AesInit(ctx->aes_encrypt, NULL, INVALID_DEVID);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err); pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
free(ctx->aes_encrypt);
ctx->aes_encrypt = NULL;
return err;
}
if (! need_decryption) {
ctx->aes_decrypt = NULL;
return 0;
}
ctx->aes_decrypt = (Aes *)malloc(sizeof(*ctx->aes_decrypt));
if (! ctx->aes_encrypt) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, MEMORY_E);
km_AesExitCommon(ctx);
return MEMORY_E;
}
err = wc_AesInit(ctx->aes_decrypt, NULL, INVALID_DEVID);
if (unlikely(err)) {
pr_err("error: km_AesInitCommon %s failed: %d\n", name, err);
free(ctx->aes_decrypt);
ctx->aes_decrypt = NULL;
km_AesExitCommon(ctx);
return err; return err;
} }
@ -133,10 +155,16 @@ static int km_AesInitCommon(struct km_AesCtx * ctx, const char * name)
static void km_AesExitCommon(struct km_AesCtx * ctx) static void km_AesExitCommon(struct km_AesCtx * ctx)
{ {
wc_AesFree(ctx->aes); if (ctx->aes_encrypt) {
free(ctx->aes); wc_AesFree(ctx->aes_encrypt);
ctx->aes = NULL; free(ctx->aes_encrypt);
km_ForceZero(ctx); ctx->aes_encrypt = NULL;
}
if (ctx->aes_decrypt) {
wc_AesFree(ctx->aes_decrypt);
free(ctx->aes_decrypt);
ctx->aes_decrypt = NULL;
}
} }
static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key, static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
@ -144,15 +172,21 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
{ {
int err; int err;
err = wc_AesSetKey(ctx->aes, in_key, key_len, NULL, 0); err = wc_AesSetKey(ctx->aes_encrypt, in_key, key_len, NULL, AES_ENCRYPTION);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err); pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
return err; return err;
} }
XMEMCPY(ctx->key, in_key, key_len); if (ctx->aes_decrypt) {
ctx->keylen = key_len; err = wc_AesSetKey(ctx->aes_decrypt, in_key, key_len, NULL, AES_DECRYPTION);
if (unlikely(err)) {
pr_err("error: km_AesSetKeyCommon %s failed: %d\n", name, err);
return err;
}
}
return 0; return 0;
} }
@ -161,25 +195,12 @@ static int km_AesSetKeyCommon(struct km_AesCtx * ctx, const u8 *in_key,
defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \ defined(LINUXKM_LKCAPI_REGISTER_AESCBC) || \
defined(LINUXKM_LKCAPI_REGISTER_AESCFB) defined(LINUXKM_LKCAPI_REGISTER_AESCFB)
static int km_AesInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER);
}
static void km_AesExit(struct crypto_skcipher *tfm) static void km_AesExit(struct crypto_skcipher *tfm)
{ {
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm); struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
km_AesExitCommon(ctx); km_AesExitCommon(ctx);
} }
static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
}
#endif /* LINUXKM_LKCAPI_REGISTER_ALL || #endif /* LINUXKM_LKCAPI_REGISTER_ALL ||
* LINUXKM_LKCAPI_REGISTER_AESCBC || * LINUXKM_LKCAPI_REGISTER_AESCBC ||
* LINUXKM_LKCAPI_REGISTER_AESCFB * LINUXKM_LKCAPI_REGISTER_AESCFB
@ -192,6 +213,19 @@ static int km_AesSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
#if defined(HAVE_AES_CBC) && \ #if defined(HAVE_AES_CBC) && \
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC)) (defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCBC))
static int km_AesCbcInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCBC_DRIVER, 1);
}
static int km_AesCbcSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCBC_DRIVER);
}
static int km_AesCbcEncrypt(struct skcipher_request *req) static int km_AesCbcEncrypt(struct skcipher_request *req)
{ {
struct crypto_skcipher * tfm = NULL; struct crypto_skcipher * tfm = NULL;
@ -206,15 +240,14 @@ static int km_AesCbcEncrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false); err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) { while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
AES_ENCRYPTION);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("wc_AesSetKey failed: %d\n", err); pr_err("wc_AesSetIV failed: %d\n", err);
return err; return err;
} }
err = wc_AesCbcEncrypt(ctx->aes, walk.dst.virt.addr, err = wc_AesCbcEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes); walk.src.virt.addr, nbytes);
if (unlikely(err)) { if (unlikely(err)) {
@ -242,15 +275,14 @@ static int km_AesCbcDecrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false); err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) { while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesSetIV(ctx->aes_decrypt, walk.iv);
AES_DECRYPTION);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("wc_AesSetKey failed"); pr_err("wc_AesSetKey failed");
return err; return err;
} }
err = wc_AesCbcDecrypt(ctx->aes, walk.dst.virt.addr, err = wc_AesCbcDecrypt(ctx->aes_decrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes); walk.src.virt.addr, nbytes);
if (unlikely(err)) { if (unlikely(err)) {
@ -271,12 +303,12 @@ static struct skcipher_alg cbcAesAlg = {
.base.cra_blocksize = AES_BLOCK_SIZE, .base.cra_blocksize = AES_BLOCK_SIZE,
.base.cra_ctxsize = sizeof(struct km_AesCtx), .base.cra_ctxsize = sizeof(struct km_AesCtx),
.base.cra_module = THIS_MODULE, .base.cra_module = THIS_MODULE,
.init = km_AesInit, .init = km_AesCbcInit,
.exit = km_AesExit, .exit = km_AesExit,
.min_keysize = AES_128_KEY_SIZE, .min_keysize = AES_128_KEY_SIZE,
.max_keysize = AES_256_KEY_SIZE, .max_keysize = AES_256_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE, .ivsize = AES_BLOCK_SIZE,
.setkey = km_AesSetKey, .setkey = km_AesCbcSetKey,
.encrypt = km_AesCbcEncrypt, .encrypt = km_AesCbcEncrypt,
.decrypt = km_AesCbcDecrypt, .decrypt = km_AesCbcDecrypt,
}; };
@ -289,6 +321,19 @@ static int cbcAesAlg_loaded = 0;
#if defined(WOLFSSL_AES_CFB) && \ #if defined(WOLFSSL_AES_CFB) && \
(defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB)) (defined(LINUXKM_LKCAPI_REGISTER_ALL) || defined(LINUXKM_LKCAPI_REGISTER_AESCFB))
static int km_AesCfbInit(struct crypto_skcipher *tfm)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesInitCommon(ctx, WOLFKM_AESCFB_DRIVER, 0);
}
static int km_AesCfbSetKey(struct crypto_skcipher *tfm, const u8 *in_key,
unsigned int key_len)
{
struct km_AesCtx * ctx = crypto_skcipher_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESCFB_DRIVER);
}
static int km_AesCfbEncrypt(struct skcipher_request *req) static int km_AesCfbEncrypt(struct skcipher_request *req)
{ {
struct crypto_skcipher * tfm = NULL; struct crypto_skcipher * tfm = NULL;
@ -303,15 +348,14 @@ static int km_AesCfbEncrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false); err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) { while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
AES_ENCRYPTION);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("wc_AesSetKey failed: %d\n", err); pr_err("wc_AesSetKey failed: %d\n", err);
return err; return err;
} }
err = wc_AesCfbEncrypt(ctx->aes, walk.dst.virt.addr, err = wc_AesCfbEncrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes); walk.src.virt.addr, nbytes);
if (unlikely(err)) { if (unlikely(err)) {
@ -339,15 +383,14 @@ static int km_AesCfbDecrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false); err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes)) { while ((nbytes = walk.nbytes)) {
err = wc_AesSetKey(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesSetIV(ctx->aes_encrypt, walk.iv);
AES_ENCRYPTION);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("wc_AesSetKey failed"); pr_err("wc_AesSetKey failed");
return err; return err;
} }
err = wc_AesCfbDecrypt(ctx->aes, walk.dst.virt.addr, err = wc_AesCfbDecrypt(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, nbytes); walk.src.virt.addr, nbytes);
if (unlikely(err)) { if (unlikely(err)) {
@ -368,12 +411,12 @@ static struct skcipher_alg cfbAesAlg = {
.base.cra_blocksize = AES_BLOCK_SIZE, .base.cra_blocksize = AES_BLOCK_SIZE,
.base.cra_ctxsize = sizeof(struct km_AesCtx), .base.cra_ctxsize = sizeof(struct km_AesCtx),
.base.cra_module = THIS_MODULE, .base.cra_module = THIS_MODULE,
.init = km_AesInit, .init = km_AesCfbInit,
.exit = km_AesExit, .exit = km_AesExit,
.min_keysize = AES_128_KEY_SIZE, .min_keysize = AES_128_KEY_SIZE,
.max_keysize = AES_256_KEY_SIZE, .max_keysize = AES_256_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE, .ivsize = AES_BLOCK_SIZE,
.setkey = km_AesSetKey, .setkey = km_AesCfbSetKey,
.encrypt = km_AesCfbEncrypt, .encrypt = km_AesCfbEncrypt,
.decrypt = km_AesCfbDecrypt, .decrypt = km_AesCfbDecrypt,
}; };
@ -390,8 +433,7 @@ static int cfbAesAlg_loaded = 0;
static int km_AesGcmInit(struct crypto_aead * tfm) static int km_AesGcmInit(struct crypto_aead * tfm)
{ {
struct km_AesCtx * ctx = crypto_aead_ctx(tfm); struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
km_ForceZero(ctx); return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER, 0);
return km_AesInitCommon(ctx, WOLFKM_AESGCM_DRIVER);
} }
static void km_AesGcmExit(struct crypto_aead * tfm) static void km_AesGcmExit(struct crypto_aead * tfm)
@ -403,8 +445,16 @@ static void km_AesGcmExit(struct crypto_aead * tfm)
static int km_AesGcmSetKey(struct crypto_aead *tfm, const u8 *in_key, static int km_AesGcmSetKey(struct crypto_aead *tfm, const u8 *in_key,
unsigned int key_len) unsigned int key_len)
{ {
int err;
struct km_AesCtx * ctx = crypto_aead_ctx(tfm); struct km_AesCtx * ctx = crypto_aead_ctx(tfm);
return km_AesSetKeyCommon(ctx, in_key, key_len, WOLFKM_AESGCM_DRIVER);
err = wc_AesGcmSetKey(ctx->aes_encrypt, in_key, key_len);
if (err) {
pr_err("error: km_AesGcmSetKey %s failed: %d\n", WOLFKM_AESGCM_DRIVER, err);
}
return err;
} }
static int km_AesGcmSetAuthsize(struct crypto_aead *tfm, unsigned int authsize) static int km_AesGcmSetAuthsize(struct crypto_aead *tfm, unsigned int authsize)
@ -454,7 +504,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
return -1; return -1;
} }
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
AES_BLOCK_SIZE); AES_BLOCK_SIZE);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err); pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@ -467,7 +517,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
return err; return err;
} }
err = wc_AesGcmEncryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft); err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
assocLeft -= assocLeft; assocLeft -= assocLeft;
scatterwalk_unmap(assoc); scatterwalk_unmap(assoc);
assoc = NULL; assoc = NULL;
@ -483,7 +533,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
if (likely(cryptLeft && nbytes)) { if (likely(cryptLeft && nbytes)) {
n = cryptLeft < nbytes ? cryptLeft : nbytes; n = cryptLeft < nbytes ? cryptLeft : nbytes;
err = wc_AesGcmEncryptUpdate(ctx->aes, walk.dst.virt.addr, err = wc_AesGcmEncryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, cryptLeft, NULL, 0); walk.src.virt.addr, cryptLeft, NULL, 0);
nbytes -= n; nbytes -= n;
cryptLeft -= n; cryptLeft -= n;
@ -497,7 +547,7 @@ static int km_AesGcmEncrypt(struct aead_request *req)
err = skcipher_walk_done(&walk, nbytes); err = skcipher_walk_done(&walk, nbytes);
} }
err = wc_AesGcmEncryptFinal(ctx->aes, authTag, tfm->authsize); err = wc_AesGcmEncryptFinal(ctx->aes_encrypt, authTag, tfm->authsize);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: wc_AesGcmEncryptFinal failed with return code %d\n", err); pr_err("error: wc_AesGcmEncryptFinal failed with return code %d\n", err);
return err; return err;
@ -542,7 +592,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
return -1; return -1;
} }
err = wc_AesGcmInit(ctx->aes, ctx->key, ctx->keylen, walk.iv, err = wc_AesGcmInit(ctx->aes_encrypt, NULL /* key */, 0 /* keylen */, walk.iv,
AES_BLOCK_SIZE); AES_BLOCK_SIZE);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: wc_AesGcmInit failed with return code %d.\n", err); pr_err("error: wc_AesGcmInit failed with return code %d.\n", err);
@ -555,7 +605,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
return err; return err;
} }
err = wc_AesGcmDecryptUpdate(ctx->aes, NULL, NULL, 0, assoc, assocLeft); err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, NULL, NULL, 0, assoc, assocLeft);
assocLeft -= assocLeft; assocLeft -= assocLeft;
scatterwalk_unmap(assoc); scatterwalk_unmap(assoc);
assoc = NULL; assoc = NULL;
@ -571,7 +621,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
if (likely(cryptLeft && nbytes)) { if (likely(cryptLeft && nbytes)) {
n = cryptLeft < nbytes ? cryptLeft : nbytes; n = cryptLeft < nbytes ? cryptLeft : nbytes;
err = wc_AesGcmDecryptUpdate(ctx->aes, walk.dst.virt.addr, err = wc_AesGcmDecryptUpdate(ctx->aes_encrypt, walk.dst.virt.addr,
walk.src.virt.addr, cryptLeft, NULL, 0); walk.src.virt.addr, cryptLeft, NULL, 0);
nbytes -= n; nbytes -= n;
cryptLeft -= n; cryptLeft -= n;
@ -585,7 +635,7 @@ static int km_AesGcmDecrypt(struct aead_request *req)
err = skcipher_walk_done(&walk, nbytes); err = skcipher_walk_done(&walk, nbytes);
} }
err = wc_AesGcmDecryptFinal(ctx->aes, origAuthTag, tfm->authsize); err = wc_AesGcmDecryptFinal(ctx->aes_encrypt, origAuthTag, tfm->authsize);
if (unlikely(err)) { if (unlikely(err)) {
pr_err("error: wc_AesGcmDecryptFinal failed with return code %d\n", err); pr_err("error: wc_AesGcmDecryptFinal failed with return code %d\n", err);