diff --git a/src/crypto/aws_lc/mod.rs b/src/crypto/aws_lc/mod.rs index 76f8a2c..8ab6b41 100644 --- a/src/crypto/aws_lc/mod.rs +++ b/src/crypto/aws_lc/mod.rs @@ -8,7 +8,7 @@ use aws_lc_rs::{ use crate::{ Algorithm, DecodingKey, EncodingKey, - crypto::{CryptoProvider, JwkUtils, JwtSigner, JwtVerifier}, + crypto::{CryptoProvider, JwtSigner, JwtVerifier}, errors::{self, Error, ErrorKind}, jwk::{EllipticCurve, ThumbprintHash}, }; @@ -18,7 +18,7 @@ mod eddsa; mod hmac; mod rsa; -fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { +fn rsa_components_from_private_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { let key_pair = aws_sig::RsaKeyPair::from_der(key_content) .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; let public = key_pair.public_key(); @@ -26,7 +26,15 @@ fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec< Ok((components.n, components.e)) } -fn extract_ec_public_key_coordinates( +fn rsa_components_from_public_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { + let public = aws_lc_rs::rsa::PublicKey::from_der(key_content) + .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; + + let components = aws_sig::RsaPublicKeyComponents::>::from(&public); + Ok((components.n, components.e)) +} + +fn ec_components_from_private_key( key_content: &[u8], alg: Algorithm, ) -> errors::Result<(EllipticCurve, Vec, Vec)> { @@ -102,9 +110,8 @@ fn new_verifier( pub static DEFAULT_PROVIDER: CryptoProvider = CryptoProvider { signer_factory: new_signer, verifier_factory: new_verifier, - jwk_utils: JwkUtils { - extract_rsa_public_key_components, - extract_ec_public_key_coordinates, - compute_digest, - }, + rsa_pub_components_from_private_key: rsa_components_from_private_key, + rsa_pub_components_from_public_key: rsa_components_from_public_key, + ec_pub_components_from_private_key: ec_components_from_private_key, + compute_digest, }; diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index d681bde..b0ba749 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -10,10 +10,15 @@ //! [`CryptoProvider`]: crate::crypto::CryptoProvider use crate::algorithms::Algorithm; -use crate::errors::Result; +use crate::errors::{self, ErrorKind, Result}; use crate::jwk::{EllipticCurve, ThumbprintHash}; use crate::{DecodingKey, EncodingKey}; +const NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR: &'static str = r###" +Could not automatically determine the process-level CryptoProvider from jsonwebtoken crate features, or your CryptoProvider does not support JWKs. +Call CryptoProvider::install_default() before this point to select a provider manually, or make sure exactly one of the 'rust_crypto' and 'aws_lc_rs' features is enabled. +See the documentation of the CryptoProvider type for more information. +"###; /// `aws_lc_rs` based CryptoProvider. #[cfg(feature = "aws_lc_rs")] pub mod aws_lc; @@ -85,8 +90,19 @@ pub struct CryptoProvider { pub signer_factory: fn(&Algorithm, &EncodingKey) -> Result>, /// A function that produces a [`JwtVerifier`] for a given [`Algorithm`] pub verifier_factory: fn(&Algorithm, &DecodingKey) -> Result>, - /// Struct with utility functions for JWK processing. - pub jwk_utils: JwkUtils, + /// Given a DER encoded private key, extract the RSA public key components (n, e) + #[allow(clippy::type_complexity)] + pub rsa_pub_components_from_private_key: fn(&[u8]) -> Result<(Vec, Vec)>, + /// Given a DER encoded public key, extract the RSA public key components (n, e) + #[allow(clippy::type_complexity)] + pub rsa_pub_components_from_public_key: fn(&[u8]) -> Result<(Vec, Vec)>, + /// Given a DER encoded private key and an algorithm, extract the associated curve + /// and the EC public key components (x, y) + #[allow(clippy::type_complexity)] + pub ec_pub_components_from_private_key: + fn(&[u8], Algorithm) -> Result<(EllipticCurve, Vec, Vec)>, + /// Given some data and a name of a hash function, compute hash_function(data) + pub compute_digest: fn(&[u8], ThumbprintHash) -> Vec, } impl CryptoProvider { @@ -123,7 +139,16 @@ See the documentation of the CryptoProvider type for more information. static INSTANCE: CryptoProvider = CryptoProvider { signer_factory: |_, _| panic!("{}", NOT_INSTALLED_ERROR), verifier_factory: |_, _| panic!("{}", NOT_INSTALLED_ERROR), - jwk_utils: JwkUtils::new_unimplemented(), + rsa_pub_components_from_private_key: |_| { + panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) + }, + rsa_pub_components_from_public_key: |_| { + panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) + }, + ec_pub_components_from_private_key: |_, _| { + panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) + }, + compute_digest: |_, _| panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR), }; &INSTANCE @@ -131,41 +156,21 @@ See the documentation of the CryptoProvider type for more information. } } -/// Holds utility functions required for JWK processing. -/// Use the [`JwkUtils::new_unimplemented`] function to initialize all values to dummies. -#[derive(Clone, Debug)] -pub struct JwkUtils { - /// Given a DER encoded private key, extract the RSA public key components (n, e) - #[allow(clippy::type_complexity)] - pub extract_rsa_public_key_components: fn(&[u8]) -> Result<(Vec, Vec)>, - /// Given a DER encoded private key and an algorithm, extract the associated curve - /// and the EC public key components (x, y) - #[allow(clippy::type_complexity)] - pub extract_ec_public_key_coordinates: - fn(&[u8], Algorithm) -> Result<(EllipticCurve, Vec, Vec)>, - /// Given some data and a name of a hash function, compute hash_function(data) - pub compute_digest: fn(&[u8], ThumbprintHash) -> Vec, -} - -impl JwkUtils { - /// Initialises all values to dummies. - /// Will lead to a panic when JWKs are required, so only use it if you don't want to support JWKs. - pub const fn new_unimplemented() -> Self { - const NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR: &str = r###" -Could not automatically determine the process-level CryptoProvider from jsonwebtoken crate features, or your CryptoProvider does not support JWKs. -Call CryptoProvider::install_default() before this point to select a provider manually, or make sure exactly one of the 'rust_crypto' and 'aws_lc_rs' features is enabled. -See the documentation of the CryptoProvider type for more information. -"###; - Self { - extract_rsa_public_key_components: |_| { - panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) - }, - extract_ec_public_key_coordinates: |_, _| { - panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) - }, - compute_digest: |_, _| panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR), - } +pub(crate) fn ec_components_from_public_key( + pub_bytes: &[u8], +) -> errors::Result<(EllipticCurve, Vec, Vec)> { + let (curve, pub_elem_bytes) = match pub_bytes.len() { + 65 => (EllipticCurve::P256, 32), + 97 => (EllipticCurve::P384, 48), + _ => return Err(ErrorKind::InvalidEcdsaKey.into()), + }; + + if pub_bytes[0] != 4 { + return Err(ErrorKind::InvalidEcdsaKey.into()); } + + let (x, y) = pub_bytes[1..].split_at(pub_elem_bytes); + Ok((curve, x.to_vec(), y.to_vec())) } mod static_default { diff --git a/src/crypto/rust_crypto/mod.rs b/src/crypto/rust_crypto/mod.rs index cd0c9bd..9fa8483 100644 --- a/src/crypto/rust_crypto/mod.rs +++ b/src/crypto/rust_crypto/mod.rs @@ -1,11 +1,15 @@ -use ::rsa::{RsaPrivateKey, pkcs1::DecodeRsaPrivateKey, traits::PublicKeyParts}; +use ::rsa::{ + RsaPrivateKey, RsaPublicKey, + pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, + traits::PublicKeyParts, +}; use p256::{ecdsa::SigningKey as P256SigningKey, pkcs8::DecodePrivateKey}; use p384::ecdsa::SigningKey as P384SigningKey; use sha2::{Digest, Sha256, Sha384, Sha512}; use crate::{ Algorithm, DecodingKey, EncodingKey, - crypto::{CryptoProvider, JwkUtils, JwtSigner, JwtVerifier}, + crypto::{CryptoProvider, JwtSigner, JwtVerifier}, errors::{self, Error, ErrorKind}, jwk::{EllipticCurve, ThumbprintHash}, }; @@ -15,14 +19,20 @@ mod eddsa; mod hmac; mod rsa; -fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { +fn rsa_components_from_private_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { let private_key = RsaPrivateKey::from_pkcs1_der(key_content) .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; let public_key = private_key.to_public_key(); Ok((public_key.n().to_bytes_be(), public_key.e().to_bytes_be())) } -fn extract_ec_public_key_coordinates( +fn rsa_components_from_public_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { + let public_key = RsaPublicKey::from_pkcs1_der(key_content) + .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; + Ok((public_key.n().to_bytes_be(), public_key.e().to_bytes_be())) +} + +fn ec_components_from_private_key( key_content: &[u8], alg: Algorithm, ) -> errors::Result<(EllipticCurve, Vec, Vec)> { @@ -108,9 +118,8 @@ fn new_verifier( pub static DEFAULT_PROVIDER: CryptoProvider = CryptoProvider { signer_factory: new_signer, verifier_factory: new_verifier, - jwk_utils: JwkUtils { - extract_rsa_public_key_components, - extract_ec_public_key_coordinates, - compute_digest, - }, + rsa_pub_components_from_private_key: rsa_components_from_private_key, + rsa_pub_components_from_public_key: rsa_components_from_public_key, + ec_pub_components_from_private_key: ec_components_from_private_key, + compute_digest, }; diff --git a/src/jwk.rs b/src/jwk.rs index 200bd04..b2aa4dc 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -11,7 +11,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; use crate::crypto::CryptoProvider; use crate::serialization::b64_encode; use crate::{ - Algorithm, EncodingKey, + Algorithm, DecodingKey, EncodingKey, + decoding::DecodingKeyKind, errors::{self, Error, ErrorKind}, }; @@ -222,6 +223,25 @@ impl FromStr for KeyAlgorithm { } } +impl From for KeyAlgorithm { + fn from(algorithm: Algorithm) -> Self { + match algorithm { + Algorithm::HS256 => KeyAlgorithm::HS256, + Algorithm::HS384 => KeyAlgorithm::HS384, + Algorithm::HS512 => KeyAlgorithm::HS512, + Algorithm::ES256 => KeyAlgorithm::ES256, + Algorithm::ES384 => KeyAlgorithm::ES384, + Algorithm::RS256 => KeyAlgorithm::RS256, + Algorithm::RS384 => KeyAlgorithm::RS384, + Algorithm::RS512 => KeyAlgorithm::RS512, + Algorithm::PS256 => KeyAlgorithm::PS256, + Algorithm::PS384 => KeyAlgorithm::PS384, + Algorithm::PS512 => KeyAlgorithm::PS512, + Algorithm::EdDSA => KeyAlgorithm::EdDSA, + } + } +} + impl fmt::Display for KeyAlgorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self) @@ -437,23 +457,7 @@ impl Jwk { } pub fn from_encoding_key(key: &EncodingKey, alg: Algorithm) -> crate::errors::Result { Ok(Self { - common: CommonParameters { - key_algorithm: Some(match alg { - Algorithm::HS256 => KeyAlgorithm::HS256, - Algorithm::HS384 => KeyAlgorithm::HS384, - Algorithm::HS512 => KeyAlgorithm::HS512, - Algorithm::ES256 => KeyAlgorithm::ES256, - Algorithm::ES384 => KeyAlgorithm::ES384, - Algorithm::RS256 => KeyAlgorithm::RS256, - Algorithm::RS384 => KeyAlgorithm::RS384, - Algorithm::RS512 => KeyAlgorithm::RS512, - Algorithm::PS256 => KeyAlgorithm::PS256, - Algorithm::PS384 => KeyAlgorithm::PS384, - Algorithm::PS512 => KeyAlgorithm::PS512, - Algorithm::EdDSA => KeyAlgorithm::EdDSA, - }), - ..Default::default() - }, + common: CommonParameters { key_algorithm: Some(alg.into()), ..Default::default() }, algorithm: match key.family() { crate::algorithms::AlgorithmFamily::Hmac => { AlgorithmParameters::OctetKey(OctetKeyParameters { @@ -463,8 +467,7 @@ impl Jwk { } crate::algorithms::AlgorithmFamily::Rsa => { let (n, e) = (CryptoProvider::get_default() - .jwk_utils - .extract_rsa_public_key_components)( + .rsa_pub_components_from_private_key)( key.inner() )?; AlgorithmParameters::RSA(RSAKeyParameters { @@ -475,8 +478,7 @@ impl Jwk { } crate::algorithms::AlgorithmFamily::Ec => { let (curve, x, y) = (CryptoProvider::get_default() - .jwk_utils - .extract_ec_public_key_coordinates)( + .ec_pub_components_from_private_key)( key.inner(), alg )?; AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { @@ -493,6 +495,62 @@ impl Jwk { }) } + pub fn from_decoding_key( + key: &DecodingKey, + alg: Option, + ) -> crate::errors::Result { + Ok(Self { + common: CommonParameters { key_algorithm: alg.map(|a| a.into()), ..Default::default() }, + algorithm: match key.family() { + crate::algorithms::AlgorithmFamily::Hmac => { + let secret = match &key.kind() { + DecodingKeyKind::SecretOrDer(secret) => secret, + _ => return Err(ErrorKind::InvalidKeyFormat.into()), + }; + + AlgorithmParameters::OctetKey(OctetKeyParameters { + key_type: OctetKeyType::Octet, + value: b64_encode(secret), + }) + } + crate::algorithms::AlgorithmFamily::Rsa => { + let (n, e) = match &key.kind() { + DecodingKeyKind::RsaModulusExponent { n, e } => { + (b64_encode(n), b64_encode(e)) + } + DecodingKeyKind::SecretOrDer(der) => { + let (n, e) = (CryptoProvider::get_default() + .rsa_pub_components_from_public_key)( + der + )?; + (b64_encode(n), b64_encode(e)) + } + }; + + AlgorithmParameters::RSA(RSAKeyParameters { key_type: RSAKeyType::RSA, n, e }) + } + crate::algorithms::AlgorithmFamily::Ec => { + let (curve, x, y) = match &key.kind() { + DecodingKeyKind::SecretOrDer(pub_bytes) => { + crate::crypto::ec_components_from_public_key(pub_bytes)? + } + _ => return Err(ErrorKind::InvalidKeyFormat.into()), + }; + + AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { + key_type: EllipticCurveKeyType::EC, + curve, + x: b64_encode(x), + y: b64_encode(y), + }) + } + crate::algorithms::AlgorithmFamily::Ed => { + unimplemented!(); + } + }, + }) + } + /// Compute the thumbprint of the JWK. /// /// Per [RFC-7638](https://datatracker.ietf.org/doc/html/rfc7638) @@ -540,10 +598,7 @@ impl Jwk { }, }; - b64_encode((CryptoProvider::get_default().jwk_utils.compute_digest)( - pre.as_bytes(), - hash_function, - )) + b64_encode((CryptoProvider::get_default().compute_digest)(pre.as_bytes(), hash_function)) } } @@ -573,6 +628,7 @@ mod tests { ThumbprintHash, }; use crate::serialization::b64_encode; + use crate::{DecodingKey, EncodingKey}; #[test] #[wasm_bindgen_test] @@ -627,4 +683,30 @@ mod tests { .thumbprint(ThumbprintHash::SHA256); assert_eq!(tp.as_str(), "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs"); } + + #[test] + fn check_jwk_from_decoding_key_rsa() { + let enc_key = + EncodingKey::from_rsa_pem(include_bytes!("../tests/rsa/private_rsa_key_pkcs8.pem")) + .unwrap(); + let dec_key = + DecodingKey::from_rsa_pem(include_bytes!("../tests/rsa/public_rsa_key_pkcs8.pem")) + .unwrap(); + let expected_jwk = Jwk::from_encoding_key(&enc_key, Algorithm::RS256).unwrap(); + let jwk = Jwk::from_decoding_key(&dec_key, Some(Algorithm::RS256)).unwrap(); + assert_eq!(jwk, expected_jwk); + } + + #[test] + fn check_jwk_from_decoding_key_ec() { + let enc_key = + EncodingKey::from_ec_pem(include_bytes!("../tests/ecdsa/private_ecdsa_key.pem")) + .unwrap(); + let dec_key = + DecodingKey::from_ec_pem(include_bytes!("../tests/ecdsa/public_ecdsa_key.pem")) + .unwrap(); + let expected_jwk = Jwk::from_encoding_key(&enc_key, Algorithm::ES256).unwrap(); + let jwk = Jwk::from_decoding_key(&dec_key, Some(Algorithm::ES256)).unwrap(); + assert_eq!(jwk, expected_jwk); + } }