linuxkm rsa: comments, cleanup work buffer useage.

This commit is contained in:
jordan
2025-05-22 11:07:36 -05:00
parent 54104887ca
commit 402ebec3b7

View File

@@ -1031,6 +1031,20 @@ static int km_pkcs1_sha3_512_init(struct tfm_type *tfm)
#endif /* WOLFSSL_SHA3 */ #endif /* WOLFSSL_SHA3 */
#if !defined(LINUXKM_AKCIPHER_NO_SIGNVERIFY) #if !defined(LINUXKM_AKCIPHER_NO_SIGNVERIFY)
/*
* Generates a pkcs1 encoded signature.
*
* src:
* - req->src scatterlist is the digest to be encoded, padded, and signed.
* - req->src_len + encoding + min padding must be <= key_size.
*
* dst:
* - req->dst scatterlist is destination signature.
* - req->dst_len must be >= key_len size.
*
* See kernel (6.12 or earlier):
* - include/crypto/akcipher.h
*/
static int km_pkcs1pad_sign(struct akcipher_request *req) static int km_pkcs1pad_sign(struct akcipher_request *req)
{ {
struct crypto_akcipher * tfm = NULL; struct crypto_akcipher * tfm = NULL;
@@ -1041,6 +1055,7 @@ static int km_pkcs1pad_sign(struct akcipher_request *req)
int hash_enc_len = 0; int hash_enc_len = 0;
byte * msg = NULL; byte * msg = NULL;
byte * sig = NULL; byte * sig = NULL;
byte * work_buffer = NULL;
if (req->src == NULL || req->dst == NULL) { if (req->src == NULL || req->dst == NULL) {
err = -EINVAL; err = -EINVAL;
@@ -1072,22 +1087,17 @@ static int km_pkcs1pad_sign(struct akcipher_request *req)
goto pkcs1pad_sign_out; goto pkcs1pad_sign_out;
} }
/* allocate extra space for encoding. */ work_buffer = malloc(2 * ctx->key_len);
msg = malloc(ctx->key_len); if (unlikely(work_buffer == NULL)) {
if (unlikely(msg == NULL)) {
err = -ENOMEM; err = -ENOMEM;
goto pkcs1pad_sign_out; goto pkcs1pad_sign_out;
} }
sig = malloc(ctx->key_len); memset(work_buffer, 0, 2 * ctx->key_len);
if (unlikely(sig == NULL)) { msg = work_buffer;
err = -ENOMEM; sig = work_buffer + ctx->key_len;
goto pkcs1pad_sign_out;
}
/* copy req->src to msg */ /* copy req->src to msg */
memset(msg, 0, ctx->key_len);
memset(sig, 0, ctx->key_len);
scatterwalk_map_and_copy(msg, req->src, 0, req->src_len, 0); scatterwalk_map_and_copy(msg, req->src, 0, req->src_len, 0);
/* encode message with hash oid. */ /* encode message with hash oid. */
@@ -1118,8 +1128,7 @@ static int km_pkcs1pad_sign(struct akcipher_request *req)
err = 0; err = 0;
pkcs1pad_sign_out: pkcs1pad_sign_out:
if (msg != NULL) { free(msg); msg = NULL; } if (work_buffer != NULL) { free(work_buffer); work_buffer = NULL; }
if (sig != NULL) { free(sig); sig = NULL; }
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_info("info: exiting km_pkcs1pad_sign msg_len %d, enc_msg_len %d," pr_info("info: exiting km_pkcs1pad_sign msg_len %d, enc_msg_len %d,"
@@ -1137,7 +1146,7 @@ pkcs1pad_sign_out:
* - dst_len: digest * - dst_len: digest
* *
* dst should be null. * dst should be null.
* See kernel: * See kernel (6.12 or earlier):
* - include/crypto/akcipher.h * - include/crypto/akcipher.h
*/ */
static int km_pkcs1pad_verify(struct akcipher_request *req) static int km_pkcs1pad_verify(struct akcipher_request *req)
@@ -1153,10 +1162,11 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
int n_diff = 0; int n_diff = 0;
byte * sig = NULL; byte * sig = NULL;
byte * msg = NULL; byte * msg = NULL;
byte * work_buffer = NULL;
if (req->src == NULL || req->dst != NULL) { if (req->src == NULL || req->dst != NULL) {
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
tfm = crypto_akcipher_reqtfm(req); tfm = crypto_akcipher_reqtfm(req);
@@ -1168,7 +1178,7 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
if (ctx->key_len <= 0 || ctx->digest_len <= 0) { if (ctx->key_len <= 0 || ctx->digest_len <= 0) {
/* invalid key state */ /* invalid key state */
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
hash_enc_len = get_hash_enc_len(ctx->hash_oid); hash_enc_len = get_hash_enc_len(ctx->hash_oid);
@@ -1178,7 +1188,7 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
WOLFKM_RSA_DRIVER, hash_enc_len); WOLFKM_RSA_DRIVER, hash_enc_len);
#endif /* WOLFKM_DEBUG_RSA */ #endif /* WOLFKM_DEBUG_RSA */
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
if (msg_len != ctx->digest_len || sig_len != ctx->key_len) { if (msg_len != ctx->digest_len || sig_len != ctx->key_len) {
@@ -1188,25 +1198,20 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
WOLFKM_RSA_DRIVER, msg_len, ctx->digest_len); WOLFKM_RSA_DRIVER, msg_len, ctx->digest_len);
#endif /* WOLFKM_DEBUG_RSA */ #endif /* WOLFKM_DEBUG_RSA */
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
sig = malloc(ctx->key_len); work_buffer = malloc(2 * ctx->key_len);
if (unlikely(sig == NULL)) { if (unlikely(work_buffer == NULL)) {
err = -ENOMEM; err = -ENOMEM;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
/* allocate extra space for encoding. */ memset(work_buffer, 0, 2 * ctx->key_len);
msg = malloc(ctx->key_len); msg = work_buffer;
if (unlikely(msg == NULL)) { sig = work_buffer + ctx->key_len;
err = -ENOMEM;
goto pkcs1_verify_out;
}
/* copy sig from req->src to sig */ /* copy sig from req->src to sig */
memset(sig, 0, ctx->key_len);
memset(msg, 0, ctx->key_len);
scatterwalk_map_and_copy(sig, req->src, 0, sig_len, 0); scatterwalk_map_and_copy(sig, req->src, 0, sig_len, 0);
/* verify encoded message. */ /* verify encoded message. */
@@ -1217,7 +1222,7 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
WOLFKM_RSA_DRIVER, dec_len); WOLFKM_RSA_DRIVER, dec_len);
#endif /* WOLFKM_DEBUG_RSA */ #endif /* WOLFKM_DEBUG_RSA */
err = -EBADMSG; err = -EBADMSG;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
/* reuse sig array for digest comparison */ /* reuse sig array for digest comparison */
@@ -1228,19 +1233,18 @@ static int km_pkcs1pad_verify(struct akcipher_request *req)
enc_msg_len = wc_EncodeSignature(sig, sig, msg_len, ctx->hash_oid); enc_msg_len = wc_EncodeSignature(sig, sig, msg_len, ctx->hash_oid);
if (unlikely(enc_msg_len <= 0 || enc_msg_len != dec_len)) { if (unlikely(enc_msg_len <= 0 || enc_msg_len != dec_len)) {
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
n_diff = memcmp(sig, msg, dec_len); n_diff = memcmp(sig, msg, dec_len);
if (unlikely(n_diff != 0)) { if (unlikely(n_diff != 0)) {
err = -EKEYREJECTED; err = -EKEYREJECTED;
goto pkcs1_verify_out; goto pkcs1pad_verify_out;
} }
err = 0; err = 0;
pkcs1_verify_out: pkcs1pad_verify_out:
if (msg != NULL) { free(msg); msg = NULL; } if (work_buffer != NULL) { free(work_buffer); work_buffer = NULL; }
if (sig != NULL) { free(sig); sig = NULL; }
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_info("info: exiting km_pkcs1pad_verify msg_len %d, enc_msg_len %d," pr_info("info: exiting km_pkcs1pad_verify msg_len %d, enc_msg_len %d,"
@@ -1266,6 +1270,20 @@ static unsigned int km_pkcs1_key_size(struct crypto_sig *tfm)
return (unsigned int) ctx->key_len; return (unsigned int) ctx->key_len;
} }
/*
* Generates a pkcs1 encoded signature.
*
* src:
* - src contains the digest to be encoded, padded, and signed.
* - slen + encoding + min padding must be <= key_size.
*
* dst:
* - dst is destination signature buffer.
* - dlen must be >= key_len size.
*
* See kernel (6.13 or later):
* - include/crypto/sig.h
*/
static int km_pkcs1_sign(struct crypto_sig *tfm, static int km_pkcs1_sign(struct crypto_sig *tfm,
const void *src, unsigned int slen, const void *src, unsigned int slen,
void *dst, unsigned int dlen) void *dst, unsigned int dlen)
@@ -1273,10 +1291,11 @@ static int km_pkcs1_sign(struct crypto_sig *tfm,
struct km_rsa_ctx * ctx = NULL; struct km_rsa_ctx * ctx = NULL;
int err = 0; int err = 0;
word32 sig_len = 0; word32 sig_len = 0;
word32 enc_len = 0; word32 enc_msg_len = 0;
int hash_enc_len = 0; int hash_enc_len = 0;
byte * msg = NULL; byte * msg = NULL;
byte * sig = NULL; byte * sig = dst; /* reuse dst buffer. we will check if
* it is large enough. */
if (src == NULL || dst == NULL) { if (src == NULL || dst == NULL) {
err = -EINVAL; err = -EINVAL;
@@ -1314,30 +1333,24 @@ static int km_pkcs1_sign(struct crypto_sig *tfm,
goto pkcs1_sign_out; goto pkcs1_sign_out;
} }
sig = malloc(ctx->key_len); /* copy src to msg, and clear buffers. */
if (unlikely(sig == NULL)) {
err = -ENOMEM;
goto pkcs1_sign_out;
}
/* copy src to msg */
memset(msg, 0, ctx->key_len); memset(msg, 0, ctx->key_len);
memset(sig, 0, ctx->key_len); memset(sig, 0, ctx->key_len);
memmove(msg, src, slen); memcpy(msg, src, slen);
/* encode message with hash oid. */ /* encode message with hash oid. */
enc_len = wc_EncodeSignature(msg, msg, slen, ctx->hash_oid); enc_msg_len = wc_EncodeSignature(msg, msg, slen, ctx->hash_oid);
if (unlikely(enc_len <= 0)) { if (unlikely(enc_msg_len <= 0)) {
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_err("error: %s: wc_EncodeSignature returned: %d\n", pr_err("error: %s: wc_EncodeSignature returned: %d\n",
WOLFKM_RSA_DRIVER, enc_len); WOLFKM_RSA_DRIVER, enc_msg_len);
#endif /* WOLFKM_DEBUG_RSA */ #endif /* WOLFKM_DEBUG_RSA */
err = -EINVAL; err = -EINVAL;
goto pkcs1_sign_out; goto pkcs1_sign_out;
} }
/* sign encoded message. */ /* sign encoded message. */
sig_len = wc_RsaSSL_Sign(msg, enc_len, sig, sig_len = wc_RsaSSL_Sign(msg, enc_msg_len, sig,
ctx->key_len, ctx->key, &ctx->rng); ctx->key_len, ctx->key, &ctx->rng);
if (unlikely(sig_len != ctx->key_len)) { if (unlikely(sig_len != ctx->key_len)) {
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
@@ -1348,26 +1361,34 @@ static int km_pkcs1_sign(struct crypto_sig *tfm,
goto pkcs1_sign_out; goto pkcs1_sign_out;
} }
/* copy sig to dst */
memmove(dst, sig, ctx->key_len);
err = 0; err = 0;
pkcs1_sign_out: pkcs1_sign_out:
if (msg != NULL) { free(msg); msg = NULL; } if (msg != NULL) { free(msg); msg = NULL; }
if (sig != NULL) { free(sig); sig = NULL; }
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_info("info: exiting km_pkcs1_sign msg_len %d, enc_msg_len %d," pr_info("info: exiting km_pkcs1_sign msg_len %d, enc_msg_len %d,"
" sig_len %d, err %d", slen, enc_len, sig_len, err); " sig_len %d, err %d", slen, enc_msg_len, sig_len, err);
#endif /* WOLFKM_DEBUG_RSA */ #endif /* WOLFKM_DEBUG_RSA */
return err; return err;
return 0;
} }
/*
* Verify a pkcs1 encoded signature.
*
* src:
* - src contains the signature.
* - slen must == key_size
*
* digest:
* - the digest that was encoded, padded, and signed previously.
* - dlen must be the correct digest len.
*
* See kernel (6.13 or later):
* - include/crypto/sig.h
*/
static int km_pkcs1_verify(struct crypto_sig *tfm, static int km_pkcs1_verify(struct crypto_sig *tfm,
const void *src_v, unsigned int slen, const void *src, unsigned int slen,
const void *digest, unsigned int dlen) const void *digest, unsigned int dlen)
{ {
struct km_rsa_ctx * ctx = NULL; struct km_rsa_ctx * ctx = NULL;
@@ -1378,9 +1399,9 @@ static int km_pkcs1_verify(struct crypto_sig *tfm,
word32 enc_msg_len = 0; word32 enc_msg_len = 0;
int hash_enc_len = 0; int hash_enc_len = 0;
int n_diff = 0; int n_diff = 0;
byte * sig = NULL; byte * enc_digest = NULL;
byte * msg = NULL; byte * msg = NULL;
const u8 * src = src_v; byte * work_buffer = NULL;
if (src == NULL || digest == NULL) { if (src == NULL || digest == NULL) {
err = -EINVAL; err = -EINVAL;
@@ -1418,26 +1439,18 @@ static int km_pkcs1_verify(struct crypto_sig *tfm,
goto pkcs1_verify_out; goto pkcs1_verify_out;
} }
sig = malloc(ctx->key_len); work_buffer = malloc(2 * ctx->key_len);
if (unlikely(sig == NULL)) { if (unlikely(work_buffer == NULL)) {
err = -ENOMEM; err = -ENOMEM;
goto pkcs1_verify_out; goto pkcs1_verify_out;
} }
/* allocate extra space for encoding. */ memset(work_buffer, 0, 2 * ctx->key_len);
msg = malloc(ctx->key_len); msg = work_buffer;
if (unlikely(msg == NULL)) { enc_digest = work_buffer + ctx->key_len;
err = -ENOMEM;
goto pkcs1_verify_out;
}
/* copy sig from src to sig */ /* verify encoded message. msg contains recovered original message. */
memset(sig, 0, ctx->key_len); dec_len = wc_RsaSSL_Verify(src, sig_len, msg, sig_len, ctx->key);
memset(msg, 0, ctx->key_len);
memmove(sig, src, sig_len);
/* verify encoded message. */
dec_len = wc_RsaSSL_Verify(sig, sig_len, msg, sig_len, ctx->key);
if (unlikely(dec_len <= 0)) { if (unlikely(dec_len <= 0)) {
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_err("error: %s: wc_RsaSSL_Verify returned: %d\n", pr_err("error: %s: wc_RsaSSL_Verify returned: %d\n",
@@ -1447,18 +1460,17 @@ static int km_pkcs1_verify(struct crypto_sig *tfm,
goto pkcs1_verify_out; goto pkcs1_verify_out;
} }
/* reuse sig array for digest comparison */ memcpy(enc_digest, digest, msg_len);
memset(sig, 0, ctx->key_len);
memmove(sig, digest, msg_len);
/* encode digest with hash oid. */ /* encode digest with hash oid. */
enc_msg_len = wc_EncodeSignature(sig, sig, msg_len, ctx->hash_oid); enc_msg_len = wc_EncodeSignature(enc_digest, enc_digest, msg_len,
ctx->hash_oid);
if (unlikely(enc_msg_len <= 0 || enc_msg_len != dec_len)) { if (unlikely(enc_msg_len <= 0 || enc_msg_len != dec_len)) {
err = -EINVAL; err = -EINVAL;
goto pkcs1_verify_out; goto pkcs1_verify_out;
} }
n_diff = memcmp(sig, msg, enc_msg_len); n_diff = memcmp(enc_digest, msg, enc_msg_len);
if (unlikely(n_diff != 0)) { if (unlikely(n_diff != 0)) {
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_err("error: %s: recovered msg did not match digest: %d\n", pr_err("error: %s: recovered msg did not match digest: %d\n",
@@ -1470,8 +1482,7 @@ static int km_pkcs1_verify(struct crypto_sig *tfm,
err = 0; err = 0;
pkcs1_verify_out: pkcs1_verify_out:
if (msg != NULL) { free(msg); msg = NULL; } if (work_buffer != NULL) { free(work_buffer); work_buffer = NULL; }
if (sig != NULL) { free(sig); sig = NULL; }
#ifdef WOLFKM_DEBUG_RSA #ifdef WOLFKM_DEBUG_RSA
pr_info("info: exiting km_pkcs1_verify msg_len %d, enc_msg_len %d," pr_info("info: exiting km_pkcs1_verify msg_len %d, enc_msg_len %d,"