diff --git a/src/algorithms.rs b/src/algorithms.rs index 94eb363..0d2e07e 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -4,11 +4,16 @@ use serde::{Deserialize, Serialize}; use crate::errors::{Error, ErrorKind, Result}; +/// Supported families of algorithms. #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] pub enum AlgorithmFamily { + /// HMAC shared secret family. Hmac, + /// RSA-based public key family. Rsa, + /// Edwards curve public key family. Ec, + /// Elliptic curve public key family. Ed, } @@ -88,7 +93,8 @@ impl FromStr for Algorithm { } impl Algorithm { - pub(crate) fn family(self) -> AlgorithmFamily { + /// The family of the algorithm. + pub fn family(self) -> AlgorithmFamily { match self { Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => AlgorithmFamily::Hmac, Algorithm::RS256 diff --git a/src/lib.rs b/src/lib.rs index 920b996..b15c62d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ compile_error!( #[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))] compile_error!("at least one of the features \"rust_crypto\" or \"aws_lc_rs\" must be enabled"); -pub use algorithms::Algorithm; +pub use algorithms::{Algorithm, AlgorithmFamily}; pub use decoding::{DecodingKey, TokenData, decode, decode_header}; pub use encoding::{EncodingKey, encode}; pub use header::Header; diff --git a/src/validation.rs b/src/validation.rs index 96ae3ac..71a8583 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -6,7 +6,7 @@ use std::marker::PhantomData; use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer}; -use crate::algorithms::Algorithm; +use crate::algorithms::{Algorithm, AlgorithmFamily}; use crate::errors::{ErrorKind, Result, new_error}; /// Contains the various validations that are applied after decoding a JWT. @@ -111,12 +111,21 @@ pub struct Validation { impl Validation { /// Create a default validation setup allowing the given alg pub fn new(alg: Algorithm) -> Validation { + Self::new_impl(vec![alg]) + } + + /// Create a default validation setup allowing any algorithm in the family + pub fn new_for_family(family: AlgorithmFamily) -> Validation { + Self::new_impl(family.algorithms().into_iter().copied().collect()) + } + + fn new_impl(algorithms: Vec) -> Validation { let mut required_claims = HashSet::with_capacity(1); required_claims.insert("exp".to_owned()); Validation { required_spec_claims: required_claims, - algorithms: vec![alg], + algorithms, leeway: 60, reject_tokens_expiring_in_less_than: 0,