diff --git a/Cargo.lock b/Cargo.lock index f2806dd02..f722d181e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1281,7 +1281,7 @@ dependencies = [ "prost-types", "tonic", "tonic-prost", - "ureq", + "ureq 3.1.4", ] [[package]] @@ -4109,7 +4109,7 @@ dependencies = [ "pkg-config", "sha3", "tar", - "ureq", + "ureq 3.1.4", "windows-sys 0.61.2", ] @@ -4201,7 +4201,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.2", "tokio", "tower-service", "tracing", @@ -7117,7 +7117,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.5.10", + "socket2 0.6.2", "thiserror 2.0.18", "tokio", "tracing", @@ -7156,7 +7156,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.2", "tracing", "windows-sys 0.60.2", ] @@ -8323,6 +8323,7 @@ dependencies = [ "rustls-pemfile", "send_wrapper", "serde", + "serde_json", "serde_with", "slab", "socket2 0.6.2", @@ -8340,6 +8341,7 @@ dependencies = [ "tracing-subscriber", "tungstenite", "ulid", + "ureq 2.12.1", "uuid", "vergen-git2", ] @@ -9829,6 +9831,22 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "url", + "webpki-roots 0.26.11", +] + [[package]] name = "ureq" version = "3.1.4" diff --git a/Cargo.toml b/Cargo.toml index 0be7d8a9c..e89a33689 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -291,6 +291,7 @@ trait-variant = "0.1.2" tungstenite = "0.28.0" twox-hash = { version = "2.1.2", features = ["xxhash32"] } ulid = "1.2.1" +ureq = "2.10" uuid = { version = "1.20.0", features = [ "v4", "v7", diff --git a/DEPENDENCIES.md b/DEPENDENCIES.md index 924cbd10d..f7b7a253a 100644 --- a/DEPENDENCIES.md +++ b/DEPENDENCIES.md @@ -852,6 +852,7 @@ unicode-xid: 0.2.6, "Apache-2.0 OR MIT", universal-hash: 0.5.1, "Apache-2.0 OR MIT", unsafe-libyaml: 0.2.11, "MIT", untrusted: 0.9.0, "ISC", +ureq: 2.12.1, "Apache-2.0 OR MIT", ureq: 3.1.4, "Apache-2.0 OR MIT", ureq-proto: 0.5.3, "Apache-2.0 OR MIT", url: 2.5.8, "Apache-2.0 OR MIT", diff --git a/core/server/Cargo.toml b/core/server/Cargo.toml index 7f84a50b6..25b801dd8 100644 --- a/core/server/Cargo.toml +++ b/core/server/Cargo.toml @@ -37,6 +37,12 @@ disable-mimalloc = [] mimalloc = ["dep:mimalloc"] iggy-web = ["dep:rust-embed", "dep:mime_guess"] +[target.'cfg(not(target_env = "musl"))'.dependencies] +hwlocality = { workspace = true } + +[target.'cfg(target_env = "musl")'.dependencies] +hwlocality = { workspace = true, features = ["vendored"] } + [dependencies] ahash = { workspace = true } anyhow = { workspace = true } @@ -96,6 +102,7 @@ rustls = { workspace = true } rustls-pemfile = { workspace = true } send_wrapper = { workspace = true } serde = { workspace = true } +serde_json.workspace = true serde_with = { workspace = true } slab = { workspace = true } socket2 = { workspace = true } @@ -112,15 +119,10 @@ tracing-appender = { workspace = true } tracing-opentelemetry = { workspace = true } tracing-subscriber = { workspace = true } tungstenite = { workspace = true } -ulid = { workspace = true } +ulid = "1.2.1" +ureq = { workspace = true } uuid = { workspace = true } -[target.'cfg(not(target_env = "musl"))'.dependencies] -hwlocality = { workspace = true } - -[target.'cfg(target_env = "musl")'.dependencies] -hwlocality = { workspace = true, features = ["vendored"] } - [build-dependencies] figment = { workspace = true, features = ["json", "toml", "env"] } vergen-git2 = { workspace = true } diff --git a/core/server/config.toml b/core/server/config.toml index 3949b9cd5..8599a2b23 100644 --- a/core/server/config.toml +++ b/core/server/config.toml @@ -117,6 +117,12 @@ decoding_secret = "top_secret$iggy123$_jwt_HS256_key#!" # `false` means the secret is in plain text. use_base64_secret = false +# Trusted issuers for A2A (Application-to-Application) authentication +[[http.jwt.trusted_issuers]] +issuer = "test-issuer" +jwks_url = "http://127.0.0.1:8081/.well-known/jwks.json" +audience = "iggy.apache.org" + # Metrics configuration for HTTP. [http.metrics] # Enable or disable the metrics endpoint. diff --git a/core/server/src/configs/defaults.rs b/core/server/src/configs/defaults.rs index ce9a84ab4..d88e15f36 100644 --- a/core/server/src/configs/defaults.rs +++ b/core/server/src/configs/defaults.rs @@ -264,6 +264,7 @@ impl Default for HttpJwtConfig { encoding_secret: SERVER_CONFIG.http.jwt.encoding_secret.parse().unwrap(), decoding_secret: SERVER_CONFIG.http.jwt.decoding_secret.parse().unwrap(), use_base64_secret: SERVER_CONFIG.http.jwt.use_base_64_secret, + trusted_issuers: None, } } } diff --git a/core/server/src/configs/http.rs b/core/server/src/configs/http.rs index b19b53cae..7e4c3383c 100644 --- a/core/server/src/configs/http.rs +++ b/core/server/src/configs/http.rs @@ -26,6 +26,13 @@ use serde::{Deserialize, Serialize}; use serde_with::DisplayFromStr; use serde_with::serde_as; +#[derive(Debug, Deserialize, Serialize, Clone, ConfigEnv)] +pub struct TrustedIssuerConfig { + pub issuer: String, + pub audience: String, + pub jwks_url: String, +} + #[derive(Debug, Deserialize, Serialize, Clone, ConfigEnv)] pub struct HttpConfig { pub enabled: bool, @@ -72,6 +79,8 @@ pub struct HttpJwtConfig { #[config_env(secret)] pub decoding_secret: String, pub use_base64_secret: bool, + #[serde(default)] + pub trusted_issuers: Option>, } #[derive(Debug, Deserialize, Serialize, Clone, ConfigEnv)] diff --git a/core/server/src/http/jwt/json_web_token.rs b/core/server/src/http/jwt/json_web_token.rs index 93862c77e..ee65828fc 100644 --- a/core/server/src/http/jwt/json_web_token.rs +++ b/core/server/src/http/jwt/json_web_token.rs @@ -33,7 +33,7 @@ pub struct JwtClaims { pub jti: String, pub iss: String, pub aud: String, - pub sub: u32, + pub sub: String, pub iat: u64, pub exp: u64, pub nbf: u64, diff --git a/core/server/src/http/jwt/jwks.rs b/core/server/src/http/jwt/jwks.rs new file mode 100644 index 000000000..fa3721c08 --- /dev/null +++ b/core/server/src/http/jwt/jwks.rs @@ -0,0 +1,160 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use iggy_common::locking::{IggyRwLock, IggyRwLockFn}; +use jsonwebtoken::DecodingKey; +use serde::Deserialize; +use serde_json; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +struct Jwk { + kty: String, + kid: Option, + n: Option, + e: Option, + x: Option, + y: Option, + crv: Option, +} + +#[derive(Debug, Deserialize)] +struct JwkSet { + keys: Vec, +} + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +struct CacheKey { + issuer: String, + kid: String, +} + +#[derive(Debug, Clone)] +pub struct JwksClient { + cache: Arc>>, +} + +impl Default for JwksClient { + fn default() -> Self { + Self { + cache: Arc::new(IggyRwLock::new(HashMap::new())), + } + } +} + +impl JwksClient { + pub async fn get_key(&self, issuer: &str, jwks_url: &str, kid: &str) -> Option { + let cache_key = CacheKey { + issuer: issuer.to_string(), + kid: kid.to_string(), + }; + + { + let cache = self.cache.read().await; + if let Some(key) = cache.get(&cache_key) { + return Some(key.clone()); + } + } + + if let Ok(key) = self.fetch_and_cache_key(issuer, jwks_url, kid).await { + return Some(key); + } + + None + } + + async fn fetch_and_cache_key( + &self, + issuer: &str, + jwks_url: &str, + kid: &str, + ) -> Result { + if let Err(e) = self.refresh_keys(issuer, jwks_url).await { + return Err(anyhow::anyhow!("Failed to refresh keys: {}", e)); + } + + let cache_key = CacheKey { + issuer: issuer.to_string(), + kid: kid.to_string(), + }; + + let cache = self.cache.read().await; + cache + .get(&cache_key) + .cloned() + .ok_or_else(|| anyhow::anyhow!("Key not found in cache after refresh")) + } + + async fn refresh_keys(&self, issuer: &str, jwks_url: &str) -> Result<(), anyhow::Error> { + let response = ureq::get(jwks_url) + .call() + .map_err(|e| anyhow::anyhow!("Failed to fetch JWKS: {}", e))?; + + let body = response + .into_string() + .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?; + + let jwks: JwkSet = serde_json::from_str(&body) + .map_err(|e| anyhow::anyhow!("Failed to parse JWKS: {}", e))?; + + let mut cache = self.cache.write().await; + + for key in jwks.keys { + if let Some(kid) = key.kid { + let decoding_key: DecodingKey = match key.kty.as_str() { + "RSA" => { + if let (Some(n), Some(e)) = (key.n.as_deref(), key.e.as_deref()) { + DecodingKey::from_rsa_components(n, e) + .map_err(|e| anyhow::anyhow!("Invalid RSA key: {}", e))? + } else { + continue; + } + } + "EC" => { + if let (Some(x), Some(y), Some(crv)) = + (key.x.as_deref(), key.y.as_deref(), key.crv.as_deref()) + { + match crv { + "P-256" => DecodingKey::from_ec_components(x, y) + .map_err(|e| anyhow::anyhow!("Invalid EC key: {}", e))?, + "P-384" => DecodingKey::from_ec_components(x, y) + .map_err(|e| anyhow::anyhow!("Invalid EC key: {}", e))?, + "P-521" => DecodingKey::from_ec_components(x, y) + .map_err(|e| anyhow::anyhow!("Invalid EC key: {}", e))?, + _ => continue, + } + } else { + continue; + } + } + _ => continue, + }; + + let cache_key = CacheKey { + issuer: issuer.to_string(), + kid, + }; + cache.insert(cache_key, decoding_key); + } + } + + Ok(()) + } +} diff --git a/core/server/src/http/jwt/jwt_manager.rs b/core/server/src/http/jwt/jwt_manager.rs index 8980a1389..72bba1f70 100644 --- a/core/server/src/http/jwt/jwt_manager.rs +++ b/core/server/src/http/jwt/jwt_manager.rs @@ -16,9 +16,10 @@ * under the License. */ -use crate::configs::http::HttpJwtConfig; +use crate::configs::http::{HttpJwtConfig, TrustedIssuerConfig}; use crate::http::jwt::COMPONENT; use crate::http::jwt::json_web_token::{GeneratedToken, JwtClaims, RevokedAccessToken}; +use crate::http::jwt::jwks::JwksClient; use crate::http::jwt::storage::TokenStorage; use crate::streaming::persistence::persister::PersisterKind; use ahash::AHashMap; @@ -31,6 +32,7 @@ use iggy_common::UserId; use iggy_common::locking::IggyRwLock; use iggy_common::locking::IggyRwLockFn; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, encode}; +use std::collections::HashMap; use std::sync::Arc; use tracing::{debug, error, info}; @@ -56,6 +58,8 @@ pub struct JwtManager { tokens_storage: TokenStorage, revoked_tokens: IggyRwLock>, validations: AHashMap, + jwks_client: JwksClient, + trusted_issuer: HashMap, } impl JwtManager { @@ -78,6 +82,8 @@ impl JwtManager { validator, tokens_storage: TokenStorage::new(persister, path), revoked_tokens: IggyRwLock::new(AHashMap::new()), + jwks_client: JwksClient::default(), + trusted_issuer: HashMap::new(), }) } @@ -105,7 +111,17 @@ impl JwtManager { format!("{COMPONENT} (error: {e}) - failed to get decoding key") })?, }; - JwtManager::new(persister, path, issuer, validator) + let mut manager = JwtManager::new(persister, path, issuer, validator)?; + + if let Some(trusted_issuers) = config.trusted_issuers.as_ref() { + for issuer_config in trusted_issuers { + manager + .trusted_issuer + .insert(issuer_config.issuer.clone(), issuer_config.clone()); + } + } + + Ok(manager) } fn create_validation( @@ -181,7 +197,7 @@ impl JwtManager { let nbf = iat + self.issuer.not_before.as_secs() as u64; let claims = JwtClaims { jti: uuid::Uuid::now_v7().to_string(), - sub: user_id, + sub: user_id.to_string(), aud: self.issuer.audience.to_string(), iss: self.issuer.issuer.to_string(), iat, @@ -210,7 +226,7 @@ impl JwtManager { let token_header = jsonwebtoken::decode_header(token).map_err(|_| IggyError::InvalidAccessToken)?; - let jwt_claims = self.decode(token, token_header.alg)?; + let jwt_claims = self.decode(token, token_header.alg).await?; let id = jwt_claims.claims.jti; let expiry = jwt_claims.claims.exp; if self @@ -232,15 +248,63 @@ impl JwtManager { .error(|e: &IggyError| { format!("{COMPONENT} (error: {e}) - failed to save revoked access token: {id}") })?; - self.generate(jwt_claims.claims.sub) + let user_id = jwt_claims + .claims + .sub + .parse::() + .map_err(|_| IggyError::InvalidAccessToken)?; + self.generate(user_id) } - pub fn decode( + pub async fn decode( &self, token: &str, algorithm: Algorithm, ) -> Result, IggyError> { let validation = self.validations.get(&algorithm); + let kid = jsonwebtoken::decode_header(token).ok().and_then(|h| h.kid); + + #[allow(clippy::collapsible_if)] + if let Ok(insecure) = jsonwebtoken::dangerous::insecure_decode::(token) { + debug!( + "JWT decoded insecurely, issuer: {}, kid: {:?}", + insecure.claims.iss, kid + ); + if let Some(config) = self.trusted_issuer.get(&insecure.claims.iss) { + debug!("Found trusted issuer config: {}", config.issuer); + if let Some(kid_str) = kid.as_deref() { + if let Some(decoding_key) = self + .jwks_client + .get_key(&config.issuer, &config.jwks_url, kid_str) + .await + { + debug!("Got decoding key from JWKS for kid: {}", kid_str); + let mut validation = Validation::new(algorithm); + validation.set_issuer(std::slice::from_ref(&config.issuer)); + validation.set_audience(std::slice::from_ref(&config.audience)); + debug!("Validation configured, attempting to decode JWT"); + return jsonwebtoken::decode::( + token, + &decoding_key, + &validation, + ) + .map_err(|e| { + error!("Failed to decode JWT: {}", e); + IggyError::Unauthenticated + }); + } else { + error!("Failed to get decoding key from JWKS for kid: {}", kid_str); + } + } else { + error!("No kid found in JWT header"); + } + } else { + debug!("No trusted issuer found for: {}", insecure.claims.iss); + } + } else { + error!("Failed to decode JWT insecurely"); + } + if validation.is_none() { return Err(IggyError::InvalidJwtAlgorithm( Self::map_algorithm_to_string(algorithm), diff --git a/core/server/src/http/jwt/middleware.rs b/core/server/src/http/jwt/middleware.rs index 910e02ccb..dfef00b03 100644 --- a/core/server/src/http/jwt/middleware.rs +++ b/core/server/src/http/jwt/middleware.rs @@ -26,7 +26,6 @@ use axum::{ response::Response, }; use err_trail::ErrContext; -use iggy_common::IggyError; use std::sync::Arc; const COMPONENT: &str = "JWT_MIDDLEWARE"; @@ -79,9 +78,7 @@ pub async fn jwt_auth( let jwt_claims = state .jwt_manager .decode(jwt_token, token_header.alg) - .error(|e: &IggyError| { - format!("{COMPONENT} (error: {e}) - failed to decode JWT with provided algorithm") - }) + .await .map_err(|_| UNAUTHORIZED)?; if state .jwt_manager @@ -92,10 +89,15 @@ pub async fn jwt_auth( } let request_details = request.extensions().get::().unwrap(); + let user_id = jwt_claims + .claims + .sub + .parse::() + .map_err(|_| UNAUTHORIZED)?; let identity = Identity { token_id: jwt_claims.claims.jti, token_expiry: jwt_claims.claims.exp, - user_id: jwt_claims.claims.sub, + user_id, ip_address: request_details.ip_address, }; request.extensions_mut().insert(identity); diff --git a/core/server/src/http/jwt/mod.rs b/core/server/src/http/jwt/mod.rs index 5481f110c..d5e413cb4 100644 --- a/core/server/src/http/jwt/mod.rs +++ b/core/server/src/http/jwt/mod.rs @@ -17,6 +17,7 @@ */ pub mod json_web_token; +pub mod jwks; pub mod jwt_manager; pub mod middleware; pub mod storage;