From 54cb1d8d38deff2506fcbf1e77e90f15c3c94d50 Mon Sep 17 00:00:00 2001 From: Josh Holtrop Date: Mon, 27 Apr 2026 22:31:56 -0400 Subject: [PATCH] Rust wrapper: ensure memory safety for C RNG struct - store pointer to WC_RNG instead of full struct - enforce RNG is not dropped before consumer structs The C library stores a pointer via the set_rng() methods on a few structs (e.g. RSA). This change holds a reference (or instance) of RNG within the consumer structs to ensure it is kept alive if set_rng (or now set_shared_rng) is used. --- wrapper/rust/wolfssl-wolfcrypt/Cargo.toml | 2 +- wrapper/rust/wolfssl-wolfcrypt/Makefile | 2 +- .../rust/wolfssl-wolfcrypt/src/curve25519.rs | 101 ++++++-- wrapper/rust/wolfssl-wolfcrypt/src/dh.rs | 16 +- .../rust/wolfssl-wolfcrypt/src/dilithium.rs | 18 +- wrapper/rust/wolfssl-wolfcrypt/src/ecc.rs | 177 +++++++++---- wrapper/rust/wolfssl-wolfcrypt/src/ecdsa.rs | 2 +- wrapper/rust/wolfssl-wolfcrypt/src/ed25519.rs | 6 +- wrapper/rust/wolfssl-wolfcrypt/src/ed448.rs | 6 +- wrapper/rust/wolfssl-wolfcrypt/src/lib.rs | 3 + wrapper/rust/wolfssl-wolfcrypt/src/lms.rs | 4 +- wrapper/rust/wolfssl-wolfcrypt/src/mlkem.rs | 10 +- wrapper/rust/wolfssl-wolfcrypt/src/random.rs | 78 +++--- wrapper/rust/wolfssl-wolfcrypt/src/rsa.rs | 245 +++++++++++------- .../wolfssl-wolfcrypt/src/rsa_pkcs1v15.rs | 6 +- .../tests/test_curve25519.rs | 28 +- .../rust/wolfssl-wolfcrypt/tests/test_ecc.rs | 14 +- .../wolfssl-wolfcrypt/tests/test_random.rs | 8 +- .../rust/wolfssl-wolfcrypt/tests/test_rsa.rs | 32 +-- 19 files changed, 497 insertions(+), 261 deletions(-) diff --git a/wrapper/rust/wolfssl-wolfcrypt/Cargo.toml b/wrapper/rust/wolfssl-wolfcrypt/Cargo.toml index 0e3ceed443..6a38933619 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/Cargo.toml +++ b/wrapper/rust/wolfssl-wolfcrypt/Cargo.toml @@ -11,7 +11,7 @@ categories = ["cryptography", "security", "api-bindings"] readme = "README.md" [features] -std = [] +alloc = [] rand_core = ["dep:rand_core"] aead = ["dep:aead"] cipher = ["dep:cipher"] diff --git a/wrapper/rust/wolfssl-wolfcrypt/Makefile b/wrapper/rust/wolfssl-wolfcrypt/Makefile index f705c4f8c0..0b7eac1d47 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/Makefile +++ b/wrapper/rust/wolfssl-wolfcrypt/Makefile @@ -1,4 +1,4 @@ -FEATURES := rand_core,aead,cipher,digest,mac,signature,password-hash,kem +FEATURES := alloc,rand_core,aead,cipher,digest,mac,signature,password-hash,kem CARGO_FEATURE_FLAGS := --features $(FEATURES) .PHONY: all diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/curve25519.rs b/wrapper/rust/wolfssl-wolfcrypt/src/curve25519.rs index ef3c3ca8bd..337c57dfed 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/curve25519.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/curve25519.rs @@ -26,12 +26,15 @@ functionality. #![cfg(curve25519)] #[cfg(random)] -use crate::random::RNG; +use crate::random::{RNG, RngHandle}; use crate::sys; use core::mem::MaybeUninit; pub struct Curve25519Key { wc_key: sys::curve25519_key, + /// RNG bound via `set_rng`, kept alive while the C struct holds its pointer. + #[cfg(random)] + rng: Option, } impl Curve25519Key { @@ -73,7 +76,7 @@ impl Curve25519Key { /// Returns either Ok(curve25519key) on success or Err(e) containing the /// wolfSSL library error code value. #[cfg(random)] - pub fn generate(rng: &mut RNG) -> Result { + pub fn generate(rng: &RNG) -> Result { let mut wc_key: MaybeUninit = MaybeUninit::uninit(); let rc = unsafe { sys::wc_curve25519_init(wc_key.as_mut_ptr()) @@ -82,9 +85,13 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { - sys::wc_curve25519_make_key(&mut rng.wc_rng, Self::KEYSIZE as i32, + sys::wc_curve25519_make_key(rng.wc_rng, Self::KEYSIZE as i32, &mut curve25519key.wc_key) }; if rc != 0 { @@ -104,12 +111,12 @@ impl Curve25519Key { /// Returns either Ok(()) on success or Err(e) containing the wolfSSL /// library error code value. #[cfg(random)] - pub fn generate_priv(rng: &mut RNG, out: &mut [u8]) -> Result<(), i32> { + pub fn generate_priv(rng: &RNG, out: &mut [u8]) -> Result<(), i32> { if out.len() != Self::KEYSIZE { return Err(sys::wolfCrypt_ErrorCodes_BUFFER_E); } let rc = unsafe { - sys::wc_curve25519_make_priv(&mut rng.wc_rng, Self::KEYSIZE as i32, out.as_mut_ptr()) + sys::wc_curve25519_make_priv(rng.wc_rng, Self::KEYSIZE as i32, out.as_mut_ptr()) }; if rc != 0 { return Err(rc); @@ -137,7 +144,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_curve25519_import_private(private.as_ptr(), private_size, &mut curve25519key.wc_key) @@ -169,7 +180,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let endian = if big_endian {sys::EC25519_BIG_ENDIAN} else {sys::EC25519_LITTLE_ENDIAN}; let rc = unsafe { sys::wc_curve25519_import_private_ex(private.as_ptr(), @@ -203,7 +218,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_curve25519_import_private_raw(private.as_ptr(), private_size, public.as_ptr(), public_size, @@ -238,7 +257,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let endian = if big_endian {sys::EC25519_BIG_ENDIAN} else {sys::EC25519_LITTLE_ENDIAN}; let rc = unsafe { sys::wc_curve25519_import_private_raw_ex(private.as_ptr(), @@ -271,7 +294,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_curve25519_import_public(public.as_ptr(), public_size, &mut curve25519key.wc_key) @@ -303,7 +330,11 @@ impl Curve25519Key { return Err(rc); } let wc_key = unsafe { wc_key.assume_init() }; - let mut curve25519key = Curve25519Key { wc_key }; + let mut curve25519key = Curve25519Key { + wc_key, + #[cfg(random)] + rng: None, + }; let endian = if big_endian {sys::EC25519_BIG_ENDIAN} else {sys::EC25519_LITTLE_ENDIAN}; let rc = unsafe { sys::wc_curve25519_import_public_ex(public.as_ptr(), public_size, @@ -353,12 +384,12 @@ impl Curve25519Key { /// Returns either Ok(()) on success or Err(e) containing the wolfSSL /// library error code value. #[cfg(all(curve25519_blinding, random))] - pub fn make_pub_blind(private: &[u8], public: &mut [u8], rng: &mut RNG) -> Result<(), i32> { + pub fn make_pub_blind(private: &[u8], public: &mut [u8], rng: &RNG) -> Result<(), i32> { let private_size = crate::buffer_len_to_i32(private.len())?; let public_size = crate::buffer_len_to_i32(public.len())?; let rc = unsafe { sys::wc_curve25519_make_pub_blind(public_size, public.as_mut_ptr(), - private_size, private.as_ptr(), &mut rng.wc_rng) + private_size, private.as_ptr(), rng.wc_rng) }; if rc != 0 { return Err(rc); @@ -408,14 +439,14 @@ impl Curve25519Key { /// Returns either Ok(()) on success or Err(e) containing the wolfSSL /// library error code value. #[cfg(all(curve25519_blinding, random))] - pub fn make_pub_generic_blind(private: &[u8], public: &mut [u8], basepoint: &[u8], rng: &mut RNG) -> Result<(), i32> { + pub fn make_pub_generic_blind(private: &[u8], public: &mut [u8], basepoint: &[u8], rng: &RNG) -> Result<(), i32> { let private_size = crate::buffer_len_to_i32(private.len())?; let public_size = crate::buffer_len_to_i32(public.len())?; let basepoint_size = crate::buffer_len_to_i32(basepoint.len())?; let rc = unsafe { sys::wc_curve25519_generic_blind(public_size, public.as_mut_ptr(), private_size, private.as_ptr(), basepoint_size, basepoint.as_ptr(), - &mut rng.wc_rng) + rng.wc_rng) }; if rc != 0 { return Err(rc); @@ -454,27 +485,57 @@ impl Curve25519Key { /// This is necessary when generating a shared secret if wolfSSL is built /// with the `WOLFSSL_CURVE25519_BLINDING` build option enabled. /// + /// The key takes ownership of the RNG, so the underlying `WC_RNG` is + /// guaranteed to outlive this key. + /// /// # Parameters /// /// * `rng`: The `RNG` struct instance to associate with this - /// `Curve25519Key` instance. The `RNG` struct should not be moved in - /// memory after calling this method. + /// `Curve25519Key` instance. /// /// # Returns /// /// Returns Ok(()) on success or Err(e) containing the wolfSSL library /// error code value. #[cfg(all(curve25519_blinding, random))] - pub fn set_rng(&mut self, rng: &mut RNG) -> Result<(), i32> { + pub fn set_rng(&mut self, rng: RNG) -> Result<(), i32> { + let wc_rng = rng.wc_rng; let rc = unsafe { - sys::wc_curve25519_set_rng(&mut self.wc_key, &mut rng.wc_rng) + sys::wc_curve25519_set_rng(&mut self.wc_key, wc_rng) }; if rc != 0 { return Err(rc); } + self.rng = Some(RngHandle::Owned(rng)); Ok(()) } + /// Bind a shared `RNG` to this key. Available when the `alloc` feature + /// is enabled. + #[cfg(all(curve25519_blinding, random, feature = "alloc"))] + pub fn set_shared_rng(&mut self, rng: alloc::sync::Arc) -> Result<(), i32> { + let wc_rng = rng.wc_rng; + let rc = unsafe { + sys::wc_curve25519_set_rng(&mut self.wc_key, wc_rng) + }; + if rc != 0 { + return Err(rc); + } + self.rng = Some(RngHandle::Shared(rng)); + Ok(()) + } + + /// Borrow the RNG previously bound via `set_rng` or `set_shared_rng`. + #[cfg(random)] + pub fn rng(&self) -> Option<&RNG> { + match &self.rng { + Some(RngHandle::Owned(rng)) => Some(rng), + #[cfg(feature = "alloc")] + Some(RngHandle::Shared(rng)) => Some(rng), + None => None, + } + } + /// Compute a shared secret key given a secret private key and a received /// public key. It stores the generated secret key in the buffer out and /// returns the generated key size. Supports big or little endian. diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/dh.rs b/wrapper/rust/wolfssl-wolfcrypt/src/dh.rs index 36c4fca3fa..c6f2140f37 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/dh.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/dh.rs @@ -194,7 +194,7 @@ impl DH { /// } /// ``` #[cfg(all(dh_keygen, random))] - pub fn generate(rng: &mut RNG, modulus_size: i32) -> Result { + pub fn generate(rng: &RNG, modulus_size: i32) -> Result { Self::generate_ex(rng, modulus_size, None, None) } @@ -225,7 +225,7 @@ impl DH { /// } /// ``` #[cfg(all(dh_keygen, random))] - pub fn generate_ex(rng: &mut RNG, modulus_size: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex(rng: &RNG, modulus_size: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut wc_dhkey: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -242,7 +242,7 @@ impl DH { let wc_dhkey = unsafe { wc_dhkey.assume_init() }; let mut dh = DH { wc_dhkey }; let rc = unsafe { - sys::wc_DhGenerateParams(&mut rng.wc_rng, modulus_size, &mut dh.wc_dhkey) + sys::wc_DhGenerateParams(rng.wc_rng, modulus_size, &mut dh.wc_dhkey) }; if rc != 0 { return Err(rc); @@ -921,7 +921,7 @@ impl DH { /// } /// ``` #[cfg(random)] - pub fn new_from_pgq_with_check(p: &[u8], g: &[u8], q: &[u8], trusted: i32, rng: &mut RNG) -> Result { + pub fn new_from_pgq_with_check(p: &[u8], g: &[u8], q: &[u8], trusted: i32, rng: &RNG) -> Result { Self::new_from_pgq_with_check_ex(p, g, q, trusted, rng, None, None) } @@ -1030,7 +1030,7 @@ impl DH { /// } /// ``` #[cfg(random)] - pub fn new_from_pgq_with_check_ex(p: &[u8], g: &[u8], q: &[u8], trusted: i32, rng: &mut RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn new_from_pgq_with_check_ex(p: &[u8], g: &[u8], q: &[u8], trusted: i32, rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let p_size = crate::buffer_len_to_u32(p.len())?; let g_size = crate::buffer_len_to_u32(g.len())?; let q_size = crate::buffer_len_to_u32(q.len())?; @@ -1050,7 +1050,7 @@ impl DH { let wc_dhkey = unsafe { wc_dhkey.assume_init() }; let mut dh = DH { wc_dhkey }; let rc = unsafe { - sys::wc_DhSetCheckKey(&mut dh.wc_dhkey, p.as_ptr(), p_size, g.as_ptr(), g_size, q.as_ptr(), q_size, trusted, &mut rng.wc_rng) + sys::wc_DhSetCheckKey(&mut dh.wc_dhkey, p.as_ptr(), p_size, g.as_ptr(), g_size, q.as_ptr(), q_size, trusted, rng.wc_rng) }; if rc != 0 { return Err(rc); @@ -1509,13 +1509,13 @@ impl DH { /// } /// ``` #[cfg(random)] - pub fn generate_key_pair(&mut self, rng: &mut RNG, + pub fn generate_key_pair(&mut self, rng: &RNG, private: &mut [u8], private_size: &mut u32, public: &mut [u8], public_size: &mut u32) -> Result<(), i32> { *private_size = crate::buffer_len_to_u32(private.len())?; *public_size = crate::buffer_len_to_u32(public.len())?; let rc = unsafe { - sys::wc_DhGenerateKeyPair(&mut self.wc_dhkey, &mut rng.wc_rng, + sys::wc_DhGenerateKeyPair(&mut self.wc_dhkey, rng.wc_rng, private.as_mut_ptr(), private_size, public.as_mut_ptr(), public_size) }; diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/dilithium.rs b/wrapper/rust/wolfssl-wolfcrypt/src/dilithium.rs index 0005147859..6e1ac0564c 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/dilithium.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/dilithium.rs @@ -159,7 +159,7 @@ impl Dilithium { /// } /// ``` #[cfg(all(dilithium_make_key, random))] - pub fn generate(level: u8, rng: &mut RNG) -> Result { + pub fn generate(level: u8, rng: &RNG) -> Result { Self::generate_ex(level, rng, None, None) } @@ -193,7 +193,7 @@ impl Dilithium { #[cfg(all(dilithium_make_key, random))] pub fn generate_ex( level: u8, - rng: &mut RNG, + rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option, ) -> Result { @@ -202,7 +202,7 @@ impl Dilithium { if rc != 0 { return Err(rc); } - let rc = unsafe { sys::wc_dilithium_make_key(&mut key.ws_key, &mut rng.wc_rng) }; + let rc = unsafe { sys::wc_dilithium_make_key(&mut key.ws_key, rng.wc_rng) }; if rc != 0 { return Err(rc); } @@ -859,7 +859,7 @@ impl Dilithium { &mut self, msg: &[u8], sig: &mut [u8], - rng: &mut RNG, + rng: &RNG, ) -> Result { let msg_len = crate::buffer_len_to_u32(msg.len())?; let mut sig_len = crate::buffer_len_to_u32(sig.len())?; @@ -869,7 +869,7 @@ impl Dilithium { msg.as_ptr(), msg_len, sig.as_mut_ptr(), &mut sig_len, &mut self.ws_key, - &mut rng.wc_rng, + rng.wc_rng, ) }; if rc != 0 { @@ -917,7 +917,7 @@ impl Dilithium { ctx: &[u8], msg: &[u8], sig: &mut [u8], - rng: &mut RNG, + rng: &RNG, ) -> Result { if ctx.len() > 255 { return Err(sys::wolfCrypt_ErrorCodes_BUFFER_E); @@ -931,7 +931,7 @@ impl Dilithium { msg.as_ptr(), msg_len, sig.as_mut_ptr(), &mut sig_len, &mut self.ws_key, - &mut rng.wc_rng, + rng.wc_rng, ) }; if rc != 0 { @@ -966,7 +966,7 @@ impl Dilithium { hash_alg: i32, hash: &[u8], sig: &mut [u8], - rng: &mut RNG, + rng: &RNG, ) -> Result { if ctx.len() > 255 { return Err(sys::wolfCrypt_ErrorCodes_BUFFER_E); @@ -981,7 +981,7 @@ impl Dilithium { hash.as_ptr(), hash_len, sig.as_mut_ptr(), &mut sig_len, &mut self.ws_key, - &mut rng.wc_rng, + rng.wc_rng, ) }; if rc != 0 { diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/ecc.rs b/wrapper/rust/wolfssl-wolfcrypt/src/ecc.rs index d154654a96..9f59105c36 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/ecc.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/ecc.rs @@ -30,7 +30,7 @@ wolfSSL `ecc_key` object. It ensures proper initialization and deallocation. use crate::sys; #[cfg(random)] -use crate::random::RNG; +use crate::random::{RNG, RngHandle}; use core::mem::{MaybeUninit}; /// Rust wrapper for wolfSSL `ecc_point` object. @@ -297,6 +297,10 @@ impl Drop for ECCPoint { /// `import_raw()`, or `import_raw_ex()`. pub struct ECC { pub(crate) wc_ecc_key: sys::ecc_key, + /// RNG bound to this key via `set_rng`, kept alive for as long as the C + /// struct holds its pointer. + #[cfg(random)] + rng: Option, } #[cfg(ecc_curve_ids)] @@ -421,7 +425,7 @@ impl ECC { /// } /// ``` #[cfg(random)] - pub fn generate(size: i32, rng: &mut RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate(size: i32, rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut wc_ecc_key: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -436,9 +440,13 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { - sys::wc_ecc_make_key(&mut rng.wc_rng, size, &mut ecc.wc_ecc_key) + sys::wc_ecc_make_key(rng.wc_rng, size, &mut ecc.wc_ecc_key) }; if rc != 0 { return Err(rc); @@ -478,7 +486,7 @@ impl ECC { /// } /// ``` #[cfg(random)] - pub fn generate_ex(size: i32, rng: &mut RNG, curve_id: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex(size: i32, rng: &RNG, curve_id: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut wc_ecc_key: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -493,9 +501,13 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { - sys::wc_ecc_make_key_ex(&mut rng.wc_rng, size, &mut ecc.wc_ecc_key, curve_id) + sys::wc_ecc_make_key_ex(rng.wc_rng, size, &mut ecc.wc_ecc_key, curve_id) }; if rc != 0 { return Err(rc); @@ -536,7 +548,7 @@ impl ECC { /// } /// ``` #[cfg(random)] - pub fn generate_ex2(size: i32, rng: &mut RNG, curve_id: i32, flags: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex2(size: i32, rng: &RNG, curve_id: i32, flags: i32, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut wc_ecc_key: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -551,9 +563,13 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { - sys::wc_ecc_make_key_ex2(&mut rng.wc_rng, size, &mut ecc.wc_ecc_key, curve_id, flags) + sys::wc_ecc_make_key_ex2(rng.wc_rng, size, &mut ecc.wc_ecc_key, curve_id, flags) }; if rc != 0 { return Err(rc); @@ -638,7 +654,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let mut idx = 0u32; let der_size = crate::buffer_len_to_u32(der.len())?; let rc = unsafe { @@ -701,7 +721,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let mut idx = 0u32; let der_size = crate::buffer_len_to_u32(der.len())?; let rc = unsafe { @@ -770,7 +794,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let priv_size = crate::buffer_len_to_u32(priv_buf.len())?; let pub_ptr = if pub_buf.is_empty() {core::ptr::null()} else {pub_buf.as_ptr()}; let pub_size = crate::buffer_len_to_u32(pub_buf.len())?; @@ -844,7 +872,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let priv_size = crate::buffer_len_to_u32(priv_buf.len())?; let pub_ptr = if pub_buf.is_empty() {core::ptr::null()} else {pub_buf.as_ptr()}; let pub_size = crate::buffer_len_to_u32(pub_buf.len())?; @@ -903,7 +935,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let qx_ptr = qx.as_ptr() as *const core::ffi::c_char; let qy_ptr = qy.as_ptr() as *const core::ffi::c_char; let d_ptr = d.as_ptr() as *const core::ffi::c_char; @@ -963,7 +999,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let qx_ptr = qx.as_ptr() as *const core::ffi::c_char; let qy_ptr = qy.as_ptr() as *const core::ffi::c_char; let d_ptr = d.as_ptr() as *const core::ffi::c_char; @@ -1031,7 +1071,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_ecc_import_unsigned(&mut ecc.wc_ecc_key, qx.as_ptr(), qy.as_ptr(), d.as_ptr(), curve_id) @@ -1090,7 +1134,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_ecc_import_x963(din.as_ptr(), din_size, &mut ecc.wc_ecc_key) }; @@ -1153,7 +1201,11 @@ impl ECC { return Err(rc); } let wc_ecc_key = unsafe { wc_ecc_key.assume_init() }; - let mut ecc = ECC { wc_ecc_key }; + let mut ecc = ECC { + wc_ecc_key, + #[cfg(random)] + rng: None, + }; let rc = unsafe { sys::wc_ecc_import_x963_ex(din.as_ptr(), din_size, &mut ecc.wc_ecc_key, curve_id) }; @@ -1674,9 +1726,9 @@ impl ECC { /// } /// ``` #[cfg(random)] - pub fn make_pub(&mut self, rng: Option<&mut RNG>) -> Result<(), i32> { + pub fn make_pub(&mut self, rng: Option<&RNG>) -> Result<(), i32> { let rng_ptr = match rng { - Some(rng) => &mut rng.wc_rng, + Some(rng) => rng.wc_rng, None => core::ptr::null_mut(), }; let rc = unsafe { @@ -1718,9 +1770,9 @@ impl ECC { /// } /// ``` #[cfg(random)] - pub fn make_pub_to_point(&mut self, rng: Option<&mut RNG>, heap: Option<*mut core::ffi::c_void>) -> Result { + pub fn make_pub_to_point(&mut self, rng: Option<&RNG>, heap: Option<*mut core::ffi::c_void>) -> Result { let rng_ptr = match rng { - Some(rng) => &mut rng.wc_rng, + Some(rng) => rng.wc_rng, None => core::ptr::null_mut(), }; let heap = match heap { @@ -1749,8 +1801,14 @@ impl ECC { /// # Parameters /// /// * `rng`: The `RNG` struct instance to associate with this `ECC` - /// instance. The `RNG` struct should not be moved in memory after - /// calling this method. + /// instance. + /// + /// # Safety contract + /// + /// The caller must ensure that the `RNG` instance is not dropped before + /// this `ECC` instance. The `ECC` struct holds an internal pointer to the + /// `RNG`'s underlying `WC_RNG` context, and dropping the `RNG` first + /// would result in a dangling pointer. /// /// # Returns /// @@ -1765,22 +1823,51 @@ impl ECC { /// { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::ecc::ECC; - /// let mut rng = RNG::new().expect("Failed to create RNG"); - /// let mut ecc = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); - /// ecc.set_rng(&mut rng).expect("Error with set_rng()"); + /// let blinding_rng = RNG::new().expect("Failed to create RNG"); + /// let key_gen_rng = RNG::new().expect("Failed to create RNG"); + /// let mut ecc = ECC::generate(32, &key_gen_rng, None, None).expect("Error with generate()"); + /// ecc.set_rng(blinding_rng).expect("Error with set_rng()"); /// } /// ``` #[cfg(random)] - pub fn set_rng(&mut self, rng: &mut RNG) -> Result<(), i32> { + pub fn set_rng(&mut self, rng: RNG) -> Result<(), i32> { + let wc_rng = rng.wc_rng; let rc = unsafe { - sys::wc_ecc_set_rng(&mut self.wc_ecc_key, &mut rng.wc_rng) + sys::wc_ecc_set_rng(&mut self.wc_ecc_key, wc_rng) }; if rc != 0 { return Err(rc); } + self.rng = Some(RngHandle::Owned(rng)); Ok(()) } + /// Bind a shared `RNG` to this key. Available when the `alloc` feature + /// is enabled. + #[cfg(all(random, feature = "alloc"))] + pub fn set_shared_rng(&mut self, rng: alloc::sync::Arc) -> Result<(), i32> { + let wc_rng = rng.wc_rng; + let rc = unsafe { + sys::wc_ecc_set_rng(&mut self.wc_ecc_key, wc_rng) + }; + if rc != 0 { + return Err(rc); + } + self.rng = Some(RngHandle::Shared(rng)); + Ok(()) + } + + /// Borrow the RNG previously bound via `set_rng` or `set_shared_rng`. + #[cfg(random)] + pub fn rng(&self) -> Option<&RNG> { + match &self.rng { + Some(RngHandle::Owned(rng)) => Some(rng), + #[cfg(feature = "alloc")] + Some(RngHandle::Shared(rng)) => Some(rng), + None => None, + } + } + /// Compute the ECDH shared secret using this key's private component /// and the peer public key. /// @@ -1797,17 +1884,18 @@ impl ECC { /// # Example /// /// ```rust - /// #[cfg(all(ecc_dh, random))] + /// #[cfg(all(ecc_dh, random, feature = "alloc"))] /// { + /// use std::sync::Arc; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::ecc::ECC; - /// let mut rng = RNG::new().expect("Failed to create RNG"); - /// let mut ecc0 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); - /// let mut ecc1 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); + /// let rng = Arc::new(RNG::new().expect("Failed to create RNG")); + /// let mut ecc0 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); + /// let mut ecc1 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); /// let mut ss0 = [0u8; 128]; /// let mut ss1 = [0u8; 128]; - /// ecc0.set_rng(&mut rng).expect("Error with set_rng()"); - /// ecc1.set_rng(&mut rng).expect("Error with set_rng()"); + /// ecc0.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); + /// ecc1.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let ss0_size = ecc0.shared_secret(&mut ecc1, &mut ss0).expect("Error with shared_secret()"); /// let ss1_size = ecc1.shared_secret(&mut ecc0, &mut ss1).expect("Error with shared_secret()"); /// assert_eq!(ss0_size, ss1_size); @@ -1846,18 +1934,19 @@ impl ECC { /// # Example /// /// ```rust - /// #[cfg(all(ecc_dh, random))] + /// #[cfg(all(ecc_dh, random, feature = "alloc"))] /// { + /// use std::sync::Arc; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::ecc::ECC; - /// let mut rng = RNG::new().expect("Failed to create RNG"); - /// let mut ecc0 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); - /// let mut ecc1 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); + /// let rng = Arc::new(RNG::new().expect("Failed to create RNG")); + /// let mut ecc0 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); + /// let mut ecc1 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); /// let ecc1_point = ecc1.make_pub_to_point(None, None).expect("Error with make_pub_to_point()"); /// let mut ss0 = [0u8; 128]; /// let mut ss1 = [0u8; 128]; - /// ecc0.set_rng(&mut rng).expect("Error with set_rng()"); - /// ecc1.set_rng(&mut rng).expect("Error with set_rng()"); + /// ecc0.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); + /// ecc1.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let ss0_size = ecc0.shared_secret_ex(&ecc1_point, &mut ss0).expect("Error with shared_secret_ex()"); /// let ss1_size = ecc1.shared_secret(&mut ecc0, &mut ss1).expect("Error with shared_secret()"); /// assert_eq!(ss0_size, ss1_size); @@ -1910,12 +1999,12 @@ impl ECC { /// } /// ``` #[cfg(all(ecc_sign, random))] - pub fn sign_hash(&mut self, din: &[u8], dout: &mut [u8], rng: &mut RNG) -> Result { + pub fn sign_hash(&mut self, din: &[u8], dout: &mut [u8], rng: &RNG) -> Result { let din_size = crate::buffer_len_to_u32(din.len())?; let mut dout_size = crate::buffer_len_to_u32(dout.len())?; let rc = unsafe { sys::wc_ecc_sign_hash(din.as_ptr(), din_size, dout.as_mut_ptr(), - &mut dout_size, &mut rng.wc_rng, &mut self.wc_ecc_key) + &mut dout_size, rng.wc_rng, &mut self.wc_ecc_key) }; if rc != 0 { return Err(rc); diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/ecdsa.rs b/wrapper/rust/wolfssl-wolfcrypt/src/ecdsa.rs index b771694b05..6e990df0aa 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/ecdsa.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/ecdsa.rs @@ -254,7 +254,7 @@ macro_rules! define_ecdsa_curve { der.as_mut_ptr(), &mut der_len, &mut self.inner.wc_ecc_key as *mut _ as *mut c_void, size_of::() as u32, - &mut self.rng.wc_rng, + self.rng.wc_rng, ) }; if rc != 0 { diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/ed25519.rs b/wrapper/rust/wolfssl-wolfcrypt/src/ed25519.rs index a0b91f1e27..5a950e729c 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/ed25519.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/ed25519.rs @@ -72,7 +72,7 @@ impl Ed25519 { /// let mut rng = RNG::new().expect("Error creating RNG"); /// let ed = Ed25519::generate(&mut rng).expect("Error with generate()"); /// ``` - pub fn generate(rng: &mut RNG) -> Result { + pub fn generate(rng: &RNG) -> Result { Self::generate_ex(rng, None, None) } @@ -97,7 +97,7 @@ impl Ed25519 { /// let mut rng = RNG::new().expect("Error creating RNG"); /// let ed = Ed25519::generate_ex(&mut rng, None, None).expect("Error with generate_ex()"); /// ``` - pub fn generate_ex(rng: &mut RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex(rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut ws_key: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -114,7 +114,7 @@ impl Ed25519 { let ws_key = unsafe { ws_key.assume_init() }; let mut ed25519 = Ed25519 { ws_key }; let rc = unsafe { - sys::wc_ed25519_make_key(&mut rng.wc_rng, + sys::wc_ed25519_make_key(rng.wc_rng, sys::ED25519_KEY_SIZE as i32, &mut ed25519.ws_key) }; if rc != 0 { diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/ed448.rs b/wrapper/rust/wolfssl-wolfcrypt/src/ed448.rs index 0a77aa071e..bb134ce796 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/ed448.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/ed448.rs @@ -71,7 +71,7 @@ impl Ed448 { /// let mut rng = RNG::new().expect("Error creating RNG"); /// let ed = Ed448::generate(&mut rng).expect("Error with generate()"); /// ``` - pub fn generate(rng: &mut RNG) -> Result { + pub fn generate(rng: &RNG) -> Result { Self::generate_ex(rng, None, None) } @@ -96,7 +96,7 @@ impl Ed448 { /// let mut rng = RNG::new().expect("Error creating RNG"); /// let ed = Ed448::generate_ex(&mut rng, None, None).expect("Error with generate_ex()"); /// ``` - pub fn generate_ex(rng: &mut RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex(rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut ws_key: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -113,7 +113,7 @@ impl Ed448 { let ws_key = unsafe { ws_key.assume_init() }; let mut ed448 = Ed448 { ws_key }; let rc = unsafe { - sys::wc_ed448_make_key(&mut rng.wc_rng, + sys::wc_ed448_make_key(rng.wc_rng, sys::ED448_KEY_SIZE as i32, &mut ed448.ws_key) }; if rc != 0 { diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/lib.rs b/wrapper/rust/wolfssl-wolfcrypt/src/lib.rs index 1c6a23dcdd..726fcbe5b1 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/lib.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/lib.rs @@ -20,6 +20,9 @@ #![no_std] +#[cfg(feature = "alloc")] +extern crate alloc; + /* bindgen-generated bindings to the C library */ pub mod sys; diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/lms.rs b/wrapper/rust/wolfssl-wolfcrypt/src/lms.rs index 90d4f1127f..4eca40433c 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/lms.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/lms.rs @@ -442,8 +442,8 @@ impl Lms { /// } /// ``` #[cfg(all(lms_make_key, random))] - pub fn make_key(&mut self, rng: &mut RNG) -> Result<(), i32> { - let rc = unsafe { sys::wc_LmsKey_MakeKey(&mut self.ws_key, &mut rng.wc_rng) }; + pub fn make_key(&mut self, rng: &RNG) -> Result<(), i32> { + let rc = unsafe { sys::wc_LmsKey_MakeKey(&mut self.ws_key, rng.wc_rng) }; if rc != 0 { return Err(rc); } diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/mlkem.rs b/wrapper/rust/wolfssl-wolfcrypt/src/mlkem.rs index e63709dea5..b2711fc996 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/mlkem.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/mlkem.rs @@ -123,7 +123,7 @@ impl MlKem { /// } /// ``` #[cfg(random)] - pub fn generate(key_type: i32, rng: &mut RNG) -> Result { + pub fn generate(key_type: i32, rng: &RNG) -> Result { Self::generate_ex(key_type, rng, None, None) } @@ -157,12 +157,12 @@ impl MlKem { #[cfg(random)] pub fn generate_ex( key_type: i32, - rng: &mut RNG, + rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option, ) -> Result { let key = Self::new_ex(key_type, heap, dev_id)?; - let rc = unsafe { sys::wc_MlKemKey_MakeKey(key.ws_key, &mut rng.wc_rng) }; + let rc = unsafe { sys::wc_MlKemKey_MakeKey(key.ws_key, rng.wc_rng) }; if rc != 0 { return Err(rc); } @@ -472,7 +472,7 @@ impl MlKem { &mut self, ct: &mut [u8], ss: &mut [u8], - rng: &mut RNG, + rng: &RNG, ) -> Result<(), i32> { // Verify the cipher text length is as expected based on the parameter // set (key type) in use. @@ -489,7 +489,7 @@ impl MlKem { self.ws_key, ct.as_mut_ptr(), ss.as_mut_ptr(), - &mut rng.wc_rng, + rng.wc_rng, ) }; if rc != 0 { diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/random.rs b/wrapper/rust/wolfssl-wolfcrypt/src/random.rs index a5904daa40..47319cf769 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/random.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/random.rs @@ -31,7 +31,7 @@ wolfSSL `WC_RNG` object. It ensures proper initialization and deallocation. use wolfssl_wolfcrypt::random::RNG; // Create a RNG instance. -let mut rng = RNG::new().expect("Failed to create RNG"); +let rng = RNG::new().expect("Failed to create RNG"); // Generate a single random byte value. let byte = rng.generate_byte().expect("Failed to generate a single byte"); @@ -45,18 +45,42 @@ rng.generate_block(&mut buffer).expect("Failed to generate a block"); #![cfg(random)] use crate::sys; -use core::mem::{size_of_val, MaybeUninit}; +use core::mem::size_of_val; use num_traits::PrimInt; /// A cryptographically secure random number generator based on the wolfSSL /// library. /// -/// This struct wraps the wolfssl `WC_RNG` type, providing a high-level API -/// for generating random bytes and blocks of data. The `Drop` implementation -/// ensures that the underlying wolfSSL RNG context is correctly freed when the -/// `RNG` struct goes out of scope, preventing memory leaks. +/// This struct wraps a pointer to a wolfssl `WC_RNG` allocated on the C heap, +/// providing a high-level API for generating random bytes and blocks of data. +/// The `Drop` implementation ensures that the underlying wolfSSL RNG context is +/// correctly freed when the `RNG` struct goes out of scope, preventing memory +/// leaks. +/// +/// All generation methods take `&self`. The actual mutation of the DRBG state +/// happens through the raw pointer in the C library; the `RNG` struct itself +/// is logically immutable after construction. pub struct RNG { - pub(crate) wc_rng: sys::WC_RNG, + pub(crate) wc_rng: *mut sys::WC_RNG, +} + +// Safety: the only field of `RNG` is a non-null pointer to a `WC_RNG` that +// lives on the C heap and is never reassigned after construction. Moving the +// struct between threads is sound. +unsafe impl Send for RNG {} + +// Note: `RNG` is intentionally not `Sync`. The underlying C `WC_RNG` state is +// mutated by every call to a generation routine, with no internal locking. +// Callers that need cross-thread sharing must wrap the RNG in a `Mutex` +// (typically `Arc>`). + +/// Storage for an RNG that a consumer (e.g. `RSA`, `ECC`) has been bound to +/// via `set_rng`. The consumer keeps the `RngHandle` alive for as long as the +/// C struct holds its pointer, ensuring the `WC_RNG` outlives the consumer. +pub(crate) enum RngHandle { + Owned(RNG), + #[cfg(feature = "alloc")] + Shared(alloc::sync::Arc), } impl RNG { @@ -97,7 +121,7 @@ impl RNG { return Err(rc); } } - let mut wc_rng: MaybeUninit = MaybeUninit::uninit(); + let mut wc_rng: *mut sys::WC_RNG = core::ptr::null_mut(); let heap = match heap { Some(heap) => heap, None => core::ptr::null_mut(), @@ -107,12 +131,10 @@ impl RNG { None => sys::INVALID_DEVID, }; let rc = unsafe { - sys::wc_InitRng_ex(wc_rng.as_mut_ptr(), heap, dev_id) + sys::wc_rng_new_ex(&mut wc_rng, core::ptr::null_mut(), 0, heap, dev_id) }; if rc == 0 { - let wc_rng = unsafe { wc_rng.assume_init() }; - let rng = RNG {wc_rng}; - Ok(rng) + Ok(RNG {wc_rng}) } else { Err(rc) } @@ -159,7 +181,7 @@ impl RNG { } let ptr = nonce.as_mut_ptr() as *mut u8; let size = crate::buffer_len_to_u32(size_of_val(nonce))?; - let mut wc_rng: MaybeUninit = MaybeUninit::uninit(); + let mut wc_rng: *mut sys::WC_RNG = core::ptr::null_mut(); let heap = match heap { Some(heap) => heap, None => core::ptr::null_mut(), @@ -169,12 +191,10 @@ impl RNG { None => sys::INVALID_DEVID, }; let rc = unsafe { - sys::wc_InitRngNonce_ex(wc_rng.as_mut_ptr(), ptr, size, heap, dev_id) + sys::wc_rng_new_ex(&mut wc_rng, ptr, size, heap, dev_id) }; if rc == 0 { - let wc_rng = unsafe { wc_rng.assume_init() }; - let rng = RNG {wc_rng}; - Ok(rng) + Ok(RNG {wc_rng}) } else { Err(rc) } @@ -315,9 +335,9 @@ impl RNG { /// /// A `Result` which is `Ok(u8)` containing the random byte on success or /// an `Err` with the wolfssl library return code on failure. - pub fn generate_byte(&mut self) -> Result { + pub fn generate_byte(&self) -> Result { let mut b: u8 = 0; - let rc = unsafe { sys::wc_RNG_GenerateByte(&mut self.wc_rng, &mut b) }; + let rc = unsafe { sys::wc_RNG_GenerateByte(self.wc_rng, &mut b) }; if rc == 0 { Ok(b) } else { @@ -339,10 +359,10 @@ impl RNG { /// /// A `Result` which is `Ok(())` on success or an `Err` with the wolfssl /// library return code on failure. - pub fn generate_block(&mut self, buf: &mut [T]) -> Result<(), i32> { + pub fn generate_block(&self, buf: &mut [T]) -> Result<(), i32> { let ptr = buf.as_mut_ptr() as *mut u8; let size = crate::buffer_len_to_u32(size_of_val(buf))?; - let rc = unsafe { sys::wc_RNG_GenerateBlock(&mut self.wc_rng, ptr, size) }; + let rc = unsafe { sys::wc_RNG_GenerateBlock(self.wc_rng, ptr, size) }; if rc == 0 { Ok(()) } else { @@ -371,10 +391,10 @@ impl RNG { /// rng.reseed(&seed).expect("Error with reseed()"); /// ``` #[cfg(random_hashdrbg)] - pub fn reseed(&mut self, seed: &[u8]) -> Result<(), i32> { + pub fn reseed(&self, seed: &[u8]) -> Result<(), i32> { let seed_size = crate::buffer_len_to_u32(seed.len())?; let rc = unsafe { - sys::wc_RNG_DRBG_Reseed(&mut self.wc_rng, seed.as_ptr(), seed_size) + sys::wc_RNG_DRBG_Reseed(self.wc_rng, seed.as_ptr(), seed_size) }; if rc != 0 { return Err(rc); @@ -411,22 +431,16 @@ impl rand_core::TryRng for RNG { #[cfg(feature = "rand_core")] impl rand_core::TryCryptoRng for RNG {} -impl RNG { - fn zeroize(&mut self) { - unsafe { crate::zeroize_raw(&mut self.wc_rng); } - } -} - impl Drop for RNG { /// Safely free the underlying wolfSSL RNG context. /// - /// This calls the `wc_FreeRng` wolfssl library function. + /// This calls the `wc_rng_free` wolfssl library function, which frees the + /// C-heap-allocated `WC_RNG` object. /// /// The Rust Drop trait guarantees that this method is called when the RNG /// struct goes out of scope, automatically cleaning up resources and /// preventing memory leaks. fn drop(&mut self) { - unsafe { sys::wc_FreeRng(&mut self.wc_rng); } - self.zeroize(); + unsafe { sys::wc_rng_free(self.wc_rng); } } } diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/rsa.rs b/wrapper/rust/wolfssl-wolfcrypt/src/rsa.rs index 3e89b79142..ff25bcffee 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/rsa.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/rsa.rs @@ -35,20 +35,20 @@ use std::fs; use wolfssl_wolfcrypt::random::RNG; use wolfssl_wolfcrypt::rsa::RSA; -let mut rng = RNG::new().expect("Error creating RNG"); +let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); let key_path = "../../../certs/client-keyPub.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); -rsa.set_rng(&mut rng).expect("Error with set_rng()"); +rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); let plain: &[u8] = b"Test message"; let mut enc: [u8; 512] = [0; 512]; -let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); +let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); assert!(enc_len > 0 && enc_len <= 512); let key_path = "../../../certs/client-key.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); -rsa.set_rng(&mut rng).expect("Error with set_rng()"); +rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); let mut plain_out: [u8; 512] = [0; 512]; let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); assert!(dec_len as usize == plain.len()); @@ -61,7 +61,7 @@ assert_eq!(plain_out[0..dec_len], *plain); use crate::sys; #[cfg(random)] -use crate::random::RNG; +use crate::random::{RNG, RngHandle}; use core::mem::{MaybeUninit}; /// The `RSA` struct manages the lifecycle of a wolfSSL `RsaKey` object. @@ -72,6 +72,10 @@ use core::mem::{MaybeUninit}; /// or `generate()`. pub struct RSA { pub(crate) wc_rsakey: sys::RsaKey, + /// RNG bound to this key via `set_rng`. Kept alive here so the C struct's + /// internal `WC_RNG` pointer remains valid for as long as the key exists. + #[cfg(random)] + rng: Option, } impl RSA { @@ -143,26 +147,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -191,26 +195,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der_ex(&der, None, None).expect("Error with new_from_der_ex()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -241,7 +245,11 @@ impl RSA { unsafe { sys::wc_FreeRsaKey(&mut wc_rsakey); } return Err(rc); } - let rsa = RSA { wc_rsakey }; + let rsa = RSA { + wc_rsakey, + #[cfg(random)] + rng: None, + }; Ok(rsa) } @@ -260,26 +268,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -308,26 +316,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der_ex(&der, None, None).expect("Error with new_public_from_der_ex()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -358,7 +366,11 @@ impl RSA { unsafe { sys::wc_FreeRsaKey(&mut wc_rsakey); } return Err(rc); } - let rsa = RSA { wc_rsakey }; + let rsa = RSA { + wc_rsakey, + #[cfg(random)] + rng: None, + }; Ok(rsa) } @@ -411,7 +423,11 @@ impl RSA { unsafe { sys::wc_FreeRsaKey(&mut wc_rsakey); } return Err(rc); } - Ok(RSA { wc_rsakey }) + Ok(RSA { + wc_rsakey, + #[cfg(random)] + rng: None, + }) } /// Generate a new RSA key using the given size and exponent. @@ -446,15 +462,15 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); /// rsa.check().expect("Error with check()"); /// let encrypt_size = rsa.get_encrypt_size().expect("Error with get_encrypt_size()"); /// assert_eq!(encrypt_size, 256); /// } /// ``` #[cfg(all(random, rsa_keygen))] - pub fn generate(size: i32, e: i32, rng: &mut RNG) -> Result { + pub fn generate(size: i32, e: i32, rng: &RNG) -> Result { Self::generate_ex(size, e, rng, None, None) } @@ -493,15 +509,15 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate_ex(2048, 65537, &mut rng, None, None).expect("Error with generate_ex()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate_ex(2048, 65537, &rng, None, None).expect("Error with generate_ex()"); /// rsa.check().expect("Error with check()"); /// let encrypt_size = rsa.get_encrypt_size().expect("Error with get_encrypt_size()"); /// assert_eq!(encrypt_size, 256); /// } /// ``` #[cfg(all(random, rsa_keygen))] - pub fn generate_ex(size: i32, e: i32, rng: &mut RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { + pub fn generate_ex(size: i32, e: i32, rng: &RNG, heap: Option<*mut core::ffi::c_void>, dev_id: Option) -> Result { let mut wc_rsakey: MaybeUninit = MaybeUninit::uninit(); let heap = match heap { Some(heap) => heap, @@ -518,13 +534,17 @@ impl RSA { let mut wc_rsakey = unsafe { wc_rsakey.assume_init() }; let e = e as core::ffi::c_long; let rc = unsafe { - sys::wc_MakeRsaKey(&mut wc_rsakey, size, e, &mut rng.wc_rng) + sys::wc_MakeRsaKey(&mut wc_rsakey, size, e, rng.wc_rng) }; if rc != 0 { unsafe { sys::wc_FreeRsaKey(&mut wc_rsakey); } return Err(rc); } - let rsa = RSA { wc_rsakey }; + let rsa = RSA { + wc_rsakey, + #[cfg(random)] + rng: None, + }; Ok(rsa) } @@ -556,8 +576,8 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); /// let mut e: [u8; 256] = [0; 256]; /// let mut e_size: u32 = 0; /// let mut n: [u8; 256] = [0; 256]; @@ -624,8 +644,8 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); /// let mut e: [u8; 256] = [0; 256]; /// let mut e_size: u32 = 0; /// let mut n: [u8; 256] = [0; 256]; @@ -667,8 +687,8 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); /// let encrypt_size = rsa.get_encrypt_size().expect("Error with get_encrypt_size()"); /// assert_eq!(encrypt_size, 256); /// } @@ -696,8 +716,8 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); - /// let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + /// let rng = RNG::new().expect("Error creating RNG"); + /// let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); /// rsa.check().expect("Error with check()"); /// } /// ``` @@ -729,26 +749,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -756,13 +776,13 @@ impl RSA { /// } /// ``` #[cfg(random)] - pub fn public_encrypt(&mut self, din: &[u8], dout: &mut [u8], rng: &mut RNG) -> Result { + pub fn public_encrypt(&mut self, din: &[u8], dout: &mut [u8], rng: &RNG) -> Result { let din_size = crate::buffer_len_to_u32(din.len())?; let dout_size = crate::buffer_len_to_u32(dout.len())?; let rc = unsafe { sys::wc_RsaPublicEncrypt(din.as_ptr(), din_size, dout.as_mut_ptr(), dout_size, &mut self.wc_rsakey, - &mut rng.wc_rng) + rng.wc_rng) }; if rc < 0 { return Err(rc); @@ -788,26 +808,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -855,20 +875,20 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &mut rng).expect("Error with pss_sign()"); + /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &rng).expect("Error with pss_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.pss_verify(signature, &mut verify_out, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256).expect("Error with pss_verify()"); @@ -880,12 +900,12 @@ impl RSA { /// } /// ``` #[cfg(all(random, rsa_pss))] - pub fn pss_sign(&mut self, din: &[u8], dout: &mut [u8], hash_algo: u32, mgf: i32, rng: &mut RNG) -> Result { + pub fn pss_sign(&mut self, din: &[u8], dout: &mut [u8], hash_algo: u32, mgf: i32, rng: &RNG) -> Result { let din_size = crate::buffer_len_to_u32(din.len())?; let dout_size = crate::buffer_len_to_u32(dout.len())?; let rc = unsafe { sys::wc_RsaPSS_Sign(din.as_ptr(), din_size, dout.as_mut_ptr(), dout_size, - hash_algo, mgf, &mut self.wc_rsakey, &mut rng.wc_rng) + hash_algo, mgf, &mut self.wc_rsakey, rng.wc_rng) }; if rc < 0 { return Err(rc); @@ -918,20 +938,20 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &mut rng).expect("Error with pss_sign()"); + /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &rng).expect("Error with pss_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.pss_verify(signature, &mut verify_out, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256).expect("Error with pss_verify()"); @@ -984,20 +1004,20 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &mut rng).expect("Error with pss_sign()"); + /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &rng).expect("Error with pss_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.pss_verify(signature, &mut verify_out, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256).expect("Error with pss_verify()"); @@ -1055,20 +1075,20 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &mut rng).expect("Error with pss_sign()"); + /// let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &rng).expect("Error with pss_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.pss_verify(signature, &mut verify_out, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256).expect("Error with pss_verify()"); @@ -1125,7 +1145,7 @@ impl RSA { /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = RNG::new().expect("Error creating RNG"); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); @@ -1134,22 +1154,22 @@ impl RSA { /// let mut plain = [0u8; 256]; /// plain[..msg.len()].copy_from_slice(msg); /// let mut enc = [0u8; 256]; - /// let enc_len = rsa.rsa_direct(&plain, &mut enc, RSA::PRIVATE_ENCRYPT, &mut rng).expect("Error with rsa_direct()"); + /// let enc_len = rsa.rsa_direct(&plain, &mut enc, RSA::PRIVATE_ENCRYPT, &rng).expect("Error with rsa_direct()"); /// assert_eq!(enc_len, 256); /// let mut plain_out = [0u8; 256]; - /// let dec_len = rsa.rsa_direct(&enc, &mut plain_out, RSA::PUBLIC_DECRYPT, &mut rng).expect("Error with rsa_direct()"); + /// let dec_len = rsa.rsa_direct(&enc, &mut plain_out, RSA::PUBLIC_DECRYPT, &rng).expect("Error with rsa_direct()"); /// assert_eq!(dec_len, 256); /// assert_eq!(plain_out, plain); /// } /// ``` #[cfg(all(rsa_direct, rsa_const_api))] - pub fn rsa_direct(&mut self, din: &[u8], dout: &mut [u8], typ: i32, rng: &mut RNG) -> Result { + pub fn rsa_direct(&mut self, din: &[u8], dout: &mut [u8], typ: i32, rng: &RNG) -> Result { let din_size = crate::buffer_len_to_u32(din.len())?; let mut dout_size = crate::buffer_len_to_u32(dout.len())?; let rc = unsafe { sys::wc_RsaDirect(din.as_ptr(), din_size, dout.as_mut_ptr(), &mut dout_size, - &mut self.wc_rsakey, typ, &mut rng.wc_rng) + &mut self.wc_rsakey, typ, rng.wc_rng) }; if rc < 0 { return Err(rc); @@ -1165,8 +1185,14 @@ impl RSA { /// # Parameters /// /// * `rng`: The `RNG` struct instance to associate with this `RSA` - /// instance. The `RNG` struct should not be moved in memory after - /// calling this method. + /// instance. + /// + /// # Safety contract + /// + /// The caller must ensure that the `RNG` instance is not dropped before + /// this `RSA` instance. The `RSA` struct holds an internal pointer to the + /// `RNG`'s underlying `WC_RNG` context, and dropping the `RNG` first + /// would result in a dangling pointer. /// /// # Returns /// @@ -1177,25 +1203,25 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let plain: &[u8] = b"Test message"; /// let mut enc: [u8; 512] = [0; 512]; - /// let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + /// let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); /// assert!(enc_len > 0 && enc_len <= 512); /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let mut plain_out: [u8; 512] = [0; 512]; /// let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); /// assert!(dec_len as usize == plain.len()); @@ -1203,16 +1229,47 @@ impl RSA { /// } /// ``` #[cfg(random)] - pub fn set_rng(&mut self, rng: &mut RNG) -> Result<(), i32> { + pub fn set_rng(&mut self, rng: RNG) -> Result<(), i32> { + let wc_rng = rng.wc_rng; let rc = unsafe { - sys::wc_RsaSetRNG(&mut self.wc_rsakey, &mut rng.wc_rng) + sys::wc_RsaSetRNG(&mut self.wc_rsakey, wc_rng) }; if rc != 0 { return Err(rc); } + self.rng = Some(RngHandle::Owned(rng)); Ok(()) } + /// Bind a shared `RNG` to this key for blinding during private operations. + /// + /// Like `set_rng`, but takes an `Arc` so the same RNG can be shared + /// among multiple consumers and used directly by the caller. Available + /// when the `alloc` feature is enabled. + #[cfg(all(random, feature = "alloc"))] + pub fn set_shared_rng(&mut self, rng: alloc::sync::Arc) -> Result<(), i32> { + let wc_rng = rng.wc_rng; + let rc = unsafe { + sys::wc_RsaSetRNG(&mut self.wc_rsakey, wc_rng) + }; + if rc != 0 { + return Err(rc); + } + self.rng = Some(RngHandle::Shared(rng)); + Ok(()) + } + + /// Borrow the RNG previously bound via `set_rng` or `set_shared_rng`. + #[cfg(random)] + pub fn rng(&self) -> Option<&RNG> { + match &self.rng { + Some(RngHandle::Owned(rng)) => Some(rng), + #[cfg(feature = "alloc")] + Some(RngHandle::Shared(rng)) => Some(rng), + None => None, + } + } + /// Sign the provided data with the private key. /// /// # Parameters @@ -1233,26 +1290,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.ssl_sign(msg, &mut signature, &mut rng).expect("Error with ssl_sign()"); + /// let sig_len = rsa.ssl_sign(msg, &mut signature, &rng).expect("Error with ssl_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.ssl_verify(signature, &mut verify_out).expect("Error with ssl_verify()"); @@ -1260,13 +1317,13 @@ impl RSA { /// } /// ``` #[cfg(random)] - pub fn ssl_sign(&mut self, din: &[u8], dout: &mut [u8], rng: &mut RNG) -> Result { + pub fn ssl_sign(&mut self, din: &[u8], dout: &mut [u8], rng: &RNG) -> Result { let din_size = crate::buffer_len_to_u32(din.len())?; let dout_size = crate::buffer_len_to_u32(dout.len())?; let rc = unsafe { sys::wc_RsaSSL_Sign(din.as_ptr(), din_size, dout.as_mut_ptr(), dout_size, - &mut self.wc_rsakey, &mut rng.wc_rng) + &mut self.wc_rsakey, rng.wc_rng) }; if rc < 0 { return Err(rc); @@ -1295,26 +1352,26 @@ impl RSA { /// /// ```rust /// # extern crate std; - /// #[cfg(random)] + /// #[cfg(all(random, feature = "alloc"))] /// { /// use std::fs; /// use wolfssl_wolfcrypt::random::RNG; /// use wolfssl_wolfcrypt::rsa::RSA; /// - /// let mut rng = RNG::new().expect("Error creating RNG"); + /// let rng = std::sync::Arc::new(RNG::new().expect("Error creating RNG")); /// /// let key_path = "../../../certs/client-key.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); /// let msg: &[u8] = b"This is the string to be signed!"; /// let mut signature: [u8; 512] = [0; 512]; - /// let sig_len = rsa.ssl_sign(msg, &mut signature, &mut rng).expect("Error with ssl_sign()"); + /// let sig_len = rsa.ssl_sign(msg, &mut signature, &rng).expect("Error with ssl_sign()"); /// assert!(sig_len > 0 && sig_len <= 512); /// /// let key_path = "../../../certs/client-keyPub.der"; /// let der: Vec = fs::read(key_path).expect("Error reading key file"); /// let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - /// rsa.set_rng(&mut rng).expect("Error with set_rng()"); + /// rsa.set_shared_rng(std::sync::Arc::clone(&rng)).expect("Error with set_shared_rng()"); /// let signature = &signature[0..sig_len]; /// let mut verify_out: [u8; 512] = [0; 512]; /// let verify_out_size = rsa.ssl_verify(signature, &mut verify_out).expect("Error with ssl_verify()"); diff --git a/wrapper/rust/wolfssl-wolfcrypt/src/rsa_pkcs1v15.rs b/wrapper/rust/wolfssl-wolfcrypt/src/rsa_pkcs1v15.rs index 40314f9548..576f63e222 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/src/rsa_pkcs1v15.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/src/rsa_pkcs1v15.rs @@ -149,9 +149,9 @@ pub struct SigningKey { impl SigningKey { /// Generate a fresh `N * 8`-bit RSA key with public exponent 65537. #[cfg(rsa_keygen)] - pub fn generate(mut rng: RNG) -> Result { + pub fn generate(rng: RNG) -> Result { let bits: i32 = (N * 8).try_into().map_err(|_| sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG)?; - let rsa = RSA::generate(bits, 65537, &mut rng)?; + let rsa = RSA::generate(bits, 65537, &rng)?; Ok(Self { inner: rsa, rng, _hash: PhantomData }) } @@ -186,7 +186,7 @@ impl SignerMut> for SigningKey { sig.as_mut_ptr(), &mut sig_len, &mut self.inner.wc_rsakey as *mut _ as *mut c_void, size_of::() as u32, - &mut self.rng.wc_rng, + self.rng.wc_rng, ) }; if rc != 0 || sig_len as usize != N { diff --git a/wrapper/rust/wolfssl-wolfcrypt/tests/test_curve25519.rs b/wrapper/rust/wolfssl-wolfcrypt/tests/test_curve25519.rs index ce1e3cc39d..98806b69ca 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/tests/test_curve25519.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/tests/test_curve25519.rs @@ -1,5 +1,7 @@ #![cfg(all(curve25519, random))] +#[cfg(curve25519_blinding)] +use std::sync::Arc; use wolfssl_wolfcrypt::curve25519::*; use wolfssl_wolfcrypt::random::RNG; @@ -97,14 +99,17 @@ fn test_make_pub_blind() { #[test] fn test_shared_secret() { - let mut rng = RNG::new().expect("Error with new()"); - let mut key1 = Curve25519Key::generate(&mut rng).expect("Error with generate()"); - let mut key2 = Curve25519Key::generate(&mut rng).expect("Error with generate()"); + #[cfg(curve25519_blinding)] + let rng = Arc::new(RNG::new().expect("Error with new()")); + #[cfg(not(curve25519_blinding))] + let rng = RNG::new().expect("Error with new()"); + let mut key1 = Curve25519Key::generate(&rng).expect("Error with generate()"); + let mut key2 = Curve25519Key::generate(&rng).expect("Error with generate()"); #[cfg(curve25519_blinding)] - key1.set_rng(&mut rng).expect("Error with set_rng()"); + key1.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); #[cfg(curve25519_blinding)] - key2.set_rng(&mut rng).expect("Error with set_rng()"); + key2.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let mut public_buffer = [0u8; Curve25519Key::KEYSIZE]; key1.export_public(&mut public_buffer).expect("Error with export_public()"); @@ -122,14 +127,17 @@ fn test_shared_secret() { #[test] fn test_shared_secret_ex() { - let mut rng = RNG::new().expect("Error with new()"); - let mut key1 = Curve25519Key::generate(&mut rng).expect("Error with generate()"); - let mut key2 = Curve25519Key::generate(&mut rng).expect("Error with generate()"); + #[cfg(curve25519_blinding)] + let rng = Arc::new(RNG::new().expect("Error with new()")); + #[cfg(not(curve25519_blinding))] + let rng = RNG::new().expect("Error with new()"); + let mut key1 = Curve25519Key::generate(&rng).expect("Error with generate()"); + let mut key2 = Curve25519Key::generate(&rng).expect("Error with generate()"); #[cfg(curve25519_blinding)] - key1.set_rng(&mut rng).expect("Error with set_rng()"); + key1.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); #[cfg(curve25519_blinding)] - key2.set_rng(&mut rng).expect("Error with set_rng()"); + key2.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let mut public_buffer = [0u8; Curve25519Key::KEYSIZE]; key1.export_public(&mut public_buffer).expect("Error with export_public()"); diff --git a/wrapper/rust/wolfssl-wolfcrypt/tests/test_ecc.rs b/wrapper/rust/wolfssl-wolfcrypt/tests/test_ecc.rs index 30d2d5b3fc..365913a3a5 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/tests/test_ecc.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/tests/test_ecc.rs @@ -4,6 +4,8 @@ mod common; #[cfg(any(all(ecc_import, ecc_export, ecc_sign, ecc_verify, random), random))] use std::fs; +#[cfg(all(ecc_dh, random))] +use std::sync::Arc; use wolfssl_wolfcrypt::ecc::*; #[cfg(random)] use wolfssl_wolfcrypt::random::RNG; @@ -134,7 +136,7 @@ fn test_ecc_import_export_sign_verify() { let valid = ecc.verify_hash(&signature, &hash).expect("Error with verify_hash()"); assert_eq!(valid, false); - ecc.set_rng(&mut rng).expect("Error with set_rng()"); + ecc.set_rng(rng).expect("Error with set_rng()"); } #[test] @@ -142,13 +144,13 @@ fn test_ecc_import_export_sign_verify() { fn test_ecc_shared_secret() { common::setup(); - let mut rng = RNG::new().expect("Failed to create RNG"); - let mut ecc0 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); - let mut ecc1 = ECC::generate(32, &mut rng, None, None).expect("Error with generate()"); + let rng = Arc::new(RNG::new().expect("Failed to create RNG")); + let mut ecc0 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); + let mut ecc1 = ECC::generate(32, &rng, None, None).expect("Error with generate()"); let mut ss0 = [0u8; 128]; let mut ss1 = [0u8; 128]; - ecc0.set_rng(&mut rng).expect("Error with set_rng()"); - ecc1.set_rng(&mut rng).expect("Error with set_rng()"); + ecc0.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); + ecc1.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let ss0_size = ecc0.shared_secret(&mut ecc1, &mut ss0).expect("Error with shared_secret()"); let ss1_size = ecc1.shared_secret(&mut ecc0, &mut ss1).expect("Error with shared_secret()"); assert_eq!(ss0_size, ss1_size); diff --git a/wrapper/rust/wolfssl-wolfcrypt/tests/test_random.rs b/wrapper/rust/wolfssl-wolfcrypt/tests/test_random.rs index bf7f1536a7..45977123ad 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/tests/test_random.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/tests/test_random.rs @@ -52,7 +52,7 @@ fn test_test_seed() { fn test_rng_generate_byte() { // Since a single 0x00 or 0xFF could occur occasionally, we'll combine four // bytes into a u32 and make sure they aren't all 0x00 or all 0xFF. - let mut rng = RNG::new().expect("Failed to create RNG"); + let rng = RNG::new().expect("Failed to create RNG"); let mut v: u32 = 0; for _i in 0..4 { let byte = rng.generate_byte().expect("Failed to generate a single byte"); @@ -65,7 +65,7 @@ fn test_rng_generate_byte() { // Test that generate_block works for a slice of u8. #[test] fn test_rng_generate_block_u8() { - let mut rng = RNG::new().expect("Failed to create RNG"); + let rng = RNG::new().expect("Failed to create RNG"); let mut buffer = [0u8; 32]; rng.generate_block(&mut buffer).expect("Failed to generate a block of bytes"); @@ -77,7 +77,7 @@ fn test_rng_generate_block_u8() { // Test that generate_block works for a slice of u32. #[test] fn test_rng_generate_block_u32() { - let mut rng = RNG::new().expect("Failed to create RNG"); + let rng = RNG::new().expect("Failed to create RNG"); let mut buffer = [0u32; 8]; rng.generate_block(&mut buffer).expect("Failed to generate a block of u32"); @@ -93,7 +93,7 @@ fn test_rng_generate_block_u32() { #[test] #[cfg(random_hashdrbg)] fn test_rng_reseed() { - let mut rng = RNG::new().expect("Failed to create RNG"); + let rng = RNG::new().expect("Failed to create RNG"); let seed = [1u8, 2, 3, 4]; rng.reseed(&seed).expect("Error with reseed()"); } diff --git a/wrapper/rust/wolfssl-wolfcrypt/tests/test_rsa.rs b/wrapper/rust/wolfssl-wolfcrypt/tests/test_rsa.rs index b8d19cfaec..aa0b48d577 100644 --- a/wrapper/rust/wolfssl-wolfcrypt/tests/test_rsa.rs +++ b/wrapper/rust/wolfssl-wolfcrypt/tests/test_rsa.rs @@ -5,6 +5,8 @@ mod common; #[cfg(any(all(sha256, random, rsa_pss), random, rsa_direct))] use std::fs; #[cfg(random)] +use std::sync::Arc; +#[cfg(random)] use wolfssl_wolfcrypt::random::RNG; #[cfg(any(random, rsa_direct, rsa_keygen))] use wolfssl_wolfcrypt::rsa::*; @@ -14,8 +16,8 @@ use wolfssl_wolfcrypt::rsa::*; fn test_rsa_generate() { common::setup(); - let mut rng = RNG::new().expect("Error creating RNG"); - let mut rsa = RSA::generate(2048, 65537, &mut rng).expect("Error with generate()"); + let rng = RNG::new().expect("Error creating RNG"); + let mut rsa = RSA::generate(2048, 65537, &rng).expect("Error with generate()"); rsa.check().expect("Error with check()"); let encrypt_size = rsa.get_encrypt_size().expect("Error with get_encrypt_size()"); @@ -58,20 +60,20 @@ fn test_rsa_generate() { #[test] #[cfg(random)] fn test_rsa_encrypt_decrypt() { - let mut rng = RNG::new().expect("Error creating RNG"); + let rng = Arc::new(RNG::new().expect("Error creating RNG")); let key_path = "../../../certs/client-keyPub.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - rsa.set_rng(&mut rng).expect("Error with set_rng()"); + rsa.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let plain: &[u8] = b"Test message"; let mut enc: [u8; 512] = [0; 512]; - let enc_len = rsa.public_encrypt(plain, &mut enc, &mut rng).expect("Error with public_encrypt()"); + let enc_len = rsa.public_encrypt(plain, &mut enc, &rng).expect("Error with public_encrypt()"); assert!(enc_len > 0 && enc_len <= 512); let key_path = "../../../certs/client-key.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); - rsa.set_rng(&mut rng).expect("Error with set_rng()"); + rsa.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let mut plain_out: [u8; 512] = [0; 512]; let dec_len = rsa.private_decrypt(&enc[0..enc_len], &mut plain_out).expect("Error with private_decrypt()"); assert!(dec_len as usize == plain.len()); @@ -81,20 +83,20 @@ fn test_rsa_encrypt_decrypt() { #[test] #[cfg(all(sha256, random, rsa_pss))] fn test_rsa_pss() { - let mut rng = RNG::new().expect("Error creating RNG"); + let rng = Arc::new(RNG::new().expect("Error creating RNG")); let key_path = "../../../certs/client-key.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); let msg: &[u8] = b"This is the string to be signed!"; let mut signature: [u8; 512] = [0; 512]; - let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &mut rng).expect("Error with pss_sign()"); + let sig_len = rsa.pss_sign(msg, &mut signature, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256, &rng).expect("Error with pss_sign()"); assert!(sig_len > 0 && sig_len <= 512); let key_path = "../../../certs/client-keyPub.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - rsa.set_rng(&mut rng).expect("Error with set_rng()"); + rsa.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let signature = &signature[0..sig_len]; let mut verify_out: [u8; 512] = [0; 512]; let verify_out_size = rsa.pss_verify(signature, &mut verify_out, RSA::HASH_TYPE_SHA256, RSA::MGF1SHA256).expect("Error with pss_verify()"); @@ -108,7 +110,7 @@ fn test_rsa_pss() { #[test] #[cfg(rsa_direct)] fn test_rsa_direct() { - let mut rng = RNG::new().expect("Error creating RNG"); + let rng = RNG::new().expect("Error creating RNG"); let key_path = "../../../certs/client-key.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); @@ -117,10 +119,10 @@ fn test_rsa_direct() { let mut plain = [0u8; 256]; plain[..msg.len()].copy_from_slice(msg); let mut enc = [0u8; 256]; - let enc_len = rsa.rsa_direct(&plain, &mut enc, RSA::PRIVATE_ENCRYPT, &mut rng).expect("Error with rsa_direct()"); + let enc_len = rsa.rsa_direct(&plain, &mut enc, RSA::PRIVATE_ENCRYPT, &rng).expect("Error with rsa_direct()"); assert_eq!(enc_len, 256); let mut plain_out = [0u8; 256]; - let dec_len = rsa.rsa_direct(&enc, &mut plain_out, RSA::PUBLIC_DECRYPT, &mut rng).expect("Error with rsa_direct()"); + let dec_len = rsa.rsa_direct(&enc, &mut plain_out, RSA::PUBLIC_DECRYPT, &rng).expect("Error with rsa_direct()"); assert_eq!(dec_len, 256); assert_eq!(plain_out, plain); } @@ -128,20 +130,20 @@ fn test_rsa_direct() { #[test] #[cfg(random)] fn test_rsa_ssl() { - let mut rng = RNG::new().expect("Error creating RNG"); + let rng = Arc::new(RNG::new().expect("Error creating RNG")); let key_path = "../../../certs/client-key.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_from_der(&der).expect("Error with new_from_der()"); let msg: &[u8] = b"This is the string to be signed!"; let mut signature: [u8; 512] = [0; 512]; - let sig_len = rsa.ssl_sign(msg, &mut signature, &mut rng).expect("Error with ssl_sign()"); + let sig_len = rsa.ssl_sign(msg, &mut signature, &rng).expect("Error with ssl_sign()"); assert!(sig_len > 0 && sig_len <= 512); let key_path = "../../../certs/client-keyPub.der"; let der: Vec = fs::read(key_path).expect("Error reading key file"); let mut rsa = RSA::new_public_from_der(&der).expect("Error with new_public_from_der()"); - rsa.set_rng(&mut rng).expect("Error with set_rng()"); + rsa.set_shared_rng(Arc::clone(&rng)).expect("Error with set_shared_rng()"); let signature = &signature[0..sig_len]; let mut verify_out: [u8; 512] = [0; 512]; let verify_out_size = rsa.ssl_verify(signature, &mut verify_out).expect("Error with ssl_verify()");