diff --git a/src/auth/src/credentials/external_account.rs b/src/auth/src/credentials/external_account.rs index 877de0c47c..cef2a26c27 100644 --- a/src/auth/src/credentials/external_account.rs +++ b/src/auth/src/credentials/external_account.rs @@ -125,6 +125,7 @@ use crate::headers_util::build_cacheable_headers; use crate::retry::Builder as RetryTokenProviderBuilder; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; +use crate::trust_boundary::TrustBoundary; use crate::{BuildResult, Result}; use gax::backoff_policy::BackoffPolicyArg; use gax::retry_policy::RetryPolicyArg; @@ -359,16 +360,21 @@ impl ExternalAccountConfig { where T: dynamic::SubjectTokenProvider + 'static, { + let trust_boundary_url = + crate::trust_boundary::external_account_lookup_url(&config.audience); let token_provider = ExternalAccountTokenProvider { subject_token_provider, config, }; let token_provider_with_retry = retry_builder.build(token_provider); let cache = TokenCache::new(token_provider_with_retry); + let trust_boundary = + trust_boundary_url.map(|url| Arc::new(TrustBoundary::new(cache.clone(), url))); AccessTokenCredentials { inner: Arc::new(ExternalAccountCredentials { token_provider: cache, quota_project_id, + trust_boundary, }), } } @@ -457,6 +463,7 @@ where { token_provider: T, quota_project_id: Option, + trust_boundary: Option>, } /// A builder for external account [Credentials] instances. @@ -1279,7 +1286,12 @@ where { async fn headers(&self, extensions: Extensions) -> Result> { let token = self.token_provider.token(extensions).await?; - build_cacheable_headers(&token, &self.quota_project_id) + let trust_boundary_header = self + .trust_boundary + .as_ref() + .and_then(|tb| tb.header_value()); + + build_cacheable_headers(&token, &self.quota_project_id, &trust_boundary_header) } } diff --git a/src/auth/src/credentials/impersonated.rs b/src/auth/src/credentials/impersonated.rs index a4af732861..6ac04098f7 100644 --- a/src/auth/src/credentials/impersonated.rs +++ b/src/auth/src/credentials/impersonated.rs @@ -706,7 +706,7 @@ where { async fn headers(&self, extensions: Extensions) -> Result> { let token = self.token_provider.token(extensions).await?; - build_cacheable_headers(&token, &self.quota_project_id) + build_cacheable_headers(&token, &self.quota_project_id, &None) } } diff --git a/src/auth/src/credentials/mds.rs b/src/auth/src/credentials/mds.rs index fac0c5b819..6c1c6cf0ad 100644 --- a/src/auth/src/credentials/mds.rs +++ b/src/auth/src/credentials/mds.rs @@ -81,6 +81,7 @@ use crate::mds::client::Client as MDSClient; use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; +use crate::trust_boundary::TrustBoundary; use crate::{BuildResult, Result}; use async_trait::async_trait; use gax::backoff_policy::BackoffPolicyArg; @@ -107,6 +108,7 @@ where { quota_project_id: Option, token_provider: T, + trust_boundary: Arc, } /// Creates [Credentials] instances backed by the [Metadata Service]. @@ -282,9 +284,19 @@ impl Builder { /// # }); /// ``` pub fn build_access_token_credentials(self) -> BuildResult { + let quota_project_id = self.quota_project_id.clone(); + let mds_client = MDSClient::new(self.endpoint.clone()); + let token_provider = TokenCache::new(self.build_token_provider()); + + let trust_boundary = Arc::new(TrustBoundary::new_for_mds( + token_provider.clone(), + mds_client.clone(), + )); + let mdsc = MDSCredentials { - quota_project_id: self.quota_project_id.clone(), - token_provider: TokenCache::new(self.build_token_provider()), + quota_project_id, + token_provider, + trust_boundary, }; Ok(AccessTokenCredentials { inner: Arc::new(mdsc), @@ -338,7 +350,12 @@ where { async fn headers(&self, extensions: Extensions) -> Result> { let cached_token = self.token_provider.token(extensions).await?; - build_cacheable_headers(&cached_token, &self.quota_project_id) + let trust_boundary_header_value = self.trust_boundary.header_value(); + build_cacheable_headers( + &cached_token, + &self.quota_project_id, + &trust_boundary_header_value, + ) } } @@ -567,9 +584,11 @@ mod tests { let mut mock = MockTokenProvider::new(); mock.expect_token().times(1).return_once(|| Ok(token)); + let cache = TokenCache::new(mock); let mdsc = MDSCredentials { quota_project_id: None, - token_provider: TokenCache::new(mock), + token_provider: cache.clone(), + trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())), }; let mut extensions = Extensions::new(); @@ -627,9 +646,11 @@ mod tests { .times(1) .return_once(|| Err(errors::non_retryable_from_str("fail"))); + let cache = TokenCache::new(mock); let mdsc = MDSCredentials { quota_project_id: None, - token_provider: TokenCache::new(mock), + token_provider: cache.clone(), + trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())), }; assert!(mdsc.headers(Extensions::new()).await.is_err()); } diff --git a/src/auth/src/credentials/service_account.rs b/src/auth/src/credentials/service_account.rs index 64e05ee7e1..04b8e63f0b 100644 --- a/src/auth/src/credentials/service_account.rs +++ b/src/auth/src/credentials/service_account.rs @@ -80,6 +80,7 @@ use crate::errors::{self}; use crate::headers_util::build_cacheable_headers; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; +use crate::trust_boundary::TrustBoundary; use crate::{BuildResult, Result}; use async_trait::async_trait; use http::{Extensions, HeaderMap}; @@ -328,10 +329,21 @@ impl Builder { /// /// [service account keys]: https://cloud.google.com/iam/docs/keys-create-delete#creating pub fn build_access_token_credentials(self) -> BuildResult { + let quota_project_id = self.quota_project_id.clone(); + let token_provider = self.build_token_provider()?; + let client_email = token_provider.service_account_key.client_email.clone(); + + let token_provider = TokenCache::new(token_provider); + let trust_boundary_url = crate::trust_boundary::service_account_lookup_url(&client_email); + let trust_boundary = Arc::new(TrustBoundary::new( + token_provider.clone(), + trust_boundary_url, + )); Ok(AccessTokenCredentials { inner: Arc::new(ServiceAccountCredentials { - quota_project_id: self.quota_project_id.clone(), - token_provider: TokenCache::new(self.build_token_provider()?), + quota_project_id, + token_provider, + trust_boundary, }), }) } @@ -462,6 +474,7 @@ where { token_provider: T, quota_project_id: Option, + trust_boundary: Arc, } #[derive(Debug)] @@ -573,7 +586,9 @@ where { async fn headers(&self, extensions: Extensions) -> Result> { let token = self.token_provider.token(extensions).await?; - build_cacheable_headers(&token, &self.quota_project_id) + let trust_boundary_header = self.trust_boundary.header_value(); + + build_cacheable_headers(&token, &self.quota_project_id, &trust_boundary_header) } } @@ -652,9 +667,11 @@ mod tests { let mut mock = MockTokenProvider::new(); mock.expect_token().times(1).return_once(|| Ok(token)); + let cache = TokenCache::new(mock); let sac = ServiceAccountCredentials { - token_provider: TokenCache::new(mock), + token_provider: cache.clone(), quota_project_id: None, + trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())), }; let mut extensions = Extensions::new(); @@ -694,9 +711,11 @@ mod tests { let mut mock = MockTokenProvider::new(); mock.expect_token().times(1).return_once(|| Ok(token)); + let cache = TokenCache::new(mock); let sac = ServiceAccountCredentials { - token_provider: TokenCache::new(mock), + token_provider: cache.clone(), quota_project_id: Some(quota_project.to_string()), + trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())), }; let headers = get_headers_from_cache(sac.headers(Extensions::new()).await.unwrap())?; @@ -721,9 +740,11 @@ mod tests { .times(1) .return_once(|| Err(errors::non_retryable_from_str("fail"))); + let cache = TokenCache::new(mock); let sac = ServiceAccountCredentials { - token_provider: TokenCache::new(mock), + token_provider: cache.clone(), quota_project_id: None, + trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())), }; assert!(sac.headers(Extensions::new()).await.is_err()); } diff --git a/src/auth/src/credentials/user_account.rs b/src/auth/src/credentials/user_account.rs index 1e2fdd698e..7ffe04f19d 100644 --- a/src/auth/src/credentials/user_account.rs +++ b/src/auth/src/credentials/user_account.rs @@ -505,7 +505,7 @@ where { async fn headers(&self, extensions: Extensions) -> Result> { let token = self.token_provider.token(extensions).await?; - build_cacheable_headers(&token, &self.quota_project_id) + build_cacheable_headers(&token, &self.quota_project_id, &None) } } diff --git a/src/auth/src/headers_util.rs b/src/auth/src/headers_util.rs index cd4932b54e..28469776c2 100644 --- a/src/auth/src/headers_util.rs +++ b/src/auth/src/headers_util.rs @@ -20,6 +20,7 @@ use crate::Result; use crate::credentials::{CacheableResource, QUOTA_PROJECT_KEY}; use crate::errors; use crate::token::Token; +use crate::trust_boundary::TRUST_BOUNDARY_HEADER; use http::HeaderMap; use http::header::{AUTHORIZATION, HeaderName, HeaderValue}; @@ -55,11 +56,12 @@ const API_KEY_HEADER_KEY: &str = "x-goog-api-key"; pub(crate) fn build_cacheable_headers( cached_token: &CacheableResource, quota_project_id: &Option, + trust_boundary_header: &Option, ) -> Result> { match cached_token { CacheableResource::NotModified => Ok(CacheableResource::NotModified), CacheableResource::New { entity_tag, data } => { - let headers = build_bearer_headers(data, quota_project_id)?; + let headers = build_bearer_headers(data, quota_project_id, trust_boundary_header)?; Ok(CacheableResource::New { entity_tag: entity_tag.clone(), data: headers, @@ -72,11 +74,18 @@ pub(crate) fn build_cacheable_headers( fn build_bearer_headers( token: &crate::token::Token, quota_project_id: &Option, + trust_boundary_header: &Option, ) -> Result { - build_headers(token, quota_project_id, AUTHORIZATION, |token| { - HeaderValue::from_str(&format!("{} {}", token.token_type, token.token)) - .map_err(errors::non_retryable) - }) + build_headers( + token, + quota_project_id, + trust_boundary_header, + AUTHORIZATION, + |token| { + HeaderValue::from_str(&format!("{} {}", token.token_type, token.token)) + .map_err(errors::non_retryable) + }, + ) } pub(crate) fn build_cacheable_api_key_headers( @@ -99,6 +108,7 @@ fn build_api_key_headers(token: &crate::token::Token) -> Result { build_headers( token, &None, + &None, HeaderName::from_static(API_KEY_HEADER_KEY), |token| HeaderValue::from_str(&token.token).map_err(errors::non_retryable), ) @@ -108,6 +118,7 @@ fn build_api_key_headers(token: &crate::token::Token) -> Result { fn build_headers( token: &crate::token::Token, quota_project_id: &Option, + trust_boundary_header: &Option, header_name: HeaderName, build_header_value: impl FnOnce(&crate::token::Token) -> Result, ) -> Result { @@ -124,6 +135,13 @@ fn build_headers( ); } + if let Some(trust_boundary) = trust_boundary_header { + header_map.insert( + HeaderName::from_static(TRUST_BOUNDARY_HEADER), + HeaderValue::from_str(trust_boundary).map_err(errors::non_retryable)?, + ); + } + Ok(header_map) } @@ -151,7 +169,7 @@ mod tests { data: token, }; - let result = build_cacheable_headers(&cacheable_token, &None); + let result = build_cacheable_headers(&cacheable_token, &None, &None); assert!(result.is_ok()); let cached_headers = result.unwrap(); @@ -173,7 +191,7 @@ mod tests { fn build_cacheable_headers_basic_not_modified() { let cacheable_token = CacheableResource::NotModified; - let result = build_cacheable_headers(&cacheable_token, &None); + let result = build_cacheable_headers(&cacheable_token, &None, &None); assert!(result.is_ok()); let cached_headers = result.unwrap(); @@ -192,7 +210,7 @@ mod tests { }; let quota_project_id = Some("test-project-123".to_string()); - let result = build_cacheable_headers(&cacheable_token, "a_project_id); + let result = build_cacheable_headers(&cacheable_token, "a_project_id, &None); assert!(result.is_ok()); let cached_headers = result.unwrap(); @@ -214,11 +232,45 @@ mod tests { assert_eq!(quota_project, HeaderValue::from_static("test-project-123")); } + #[test] + fn build_cacheable_headers_with_trust_boundary_success() { + let token = create_test_token("test_token", "Bearer"); + let cacheable_token = CacheableResource::New { + entity_tag: EntityTag::default(), + data: token, + }; + + let trust_boundary = Some("test-trust-boundary".to_string()); + let result = build_cacheable_headers(&cacheable_token, &None, &trust_boundary); + + assert!(result.is_ok()); + let cached_headers = result.unwrap(); + let headers = match cached_headers { + CacheableResource::New { data, .. } => data, + CacheableResource::NotModified => unreachable!("expecting new headers"), + }; + assert_eq!(headers.len(), 2, "{headers:?}"); + + let token = headers + .get(HeaderName::from_static("authorization")) + .unwrap(); + assert_eq!(token, HeaderValue::from_static("Bearer test_token")); + assert!(token.is_sensitive()); + + let trust_boundary = headers + .get(HeaderName::from_static(TRUST_BOUNDARY_HEADER)) + .unwrap(); + assert_eq!( + trust_boundary, + HeaderValue::from_static("test-trust-boundary") + ); + } + #[test] fn build_bearer_headers_different_token_type() { let token = create_test_token("special_token", "MAC"); - let result = build_bearer_headers(&token, &None); + let result = build_bearer_headers(&token, &None, &None); assert!(result.is_ok()); let headers = result.unwrap(); @@ -237,7 +289,7 @@ mod tests { fn build_bearer_headers_invalid_token() { let token = create_test_token("token with \n invalid chars", "Bearer"); - let result = build_bearer_headers(&token, &None); + let result = build_bearer_headers(&token, &None, &None); assert!(result.is_err()); let error = result.unwrap_err(); diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index 588d0342b6..0ad091a8c1 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -62,6 +62,7 @@ pub(crate) mod retry; pub mod signer; pub(crate) mod token; pub(crate) mod token_cache; +pub(crate) mod trust_boundary; /// A `Result` alias where the `Err` case is [BuildCredentialsError]. pub(crate) type BuildResult = std::result::Result; diff --git a/src/auth/src/trust_boundary.rs b/src/auth/src/trust_boundary.rs new file mode 100644 index 0000000000..3587e8a7dc --- /dev/null +++ b/src/auth/src/trust_boundary.rs @@ -0,0 +1,278 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::credentials::CacheableResource; +use crate::errors::CredentialsError; +use crate::headers_util::build_cacheable_headers; +use crate::mds::client::Client as MDSClient; +use crate::token::CachedTokenProvider; +use http::Extensions; +use reqwest::Client; +use std::clone::Clone; +use std::fmt::Debug; +use tokio::sync::watch; +use tokio::time::{Duration, sleep}; + +pub(crate) const TRUST_BOUNDARY_HEADER: &str = "x-goog-allowed-locations"; +const TRUST_BOUNDARIES_ENV_VAR: &str = "GOOGLE_AUTH_ENABLE_TRUST_BOUNDARIES"; +const NO_OP_ENCODED_LOCATIONS: &str = "0x0"; + +// Refresh interval: 1 hour +const REFRESH_INTERVAL: Duration = Duration::from_secs(3600); +// Retry interval on error: 1 minute +const ERROR_RETRY_INTERVAL: Duration = Duration::from_secs(60); + +#[derive(Debug)] +pub(crate) struct TrustBoundary { + rx_header: watch::Receiver>, +} + +impl TrustBoundary { + pub(crate) fn new(token_provider: T, url: String) -> Self + where + T: CachedTokenProvider + 'static, + { + let enabled = Self::is_trust_boundaries_enabled(); + let (tx_header, rx_header) = watch::channel(None); + + if enabled { + tokio::spawn(refresh_task(token_provider, url, tx_header)); + } + + Self { rx_header } + } + + pub(crate) fn new_for_mds(token_provider: T, mds_client: MDSClient) -> Self + where + T: CachedTokenProvider + 'static, + { + let enabled = Self::is_trust_boundaries_enabled(); + let (tx_header, rx_header) = watch::channel(None); + + if enabled { + tokio::spawn(refresh_task_mds(token_provider, mds_client, tx_header)); + } + + Self { rx_header } + } + + fn is_trust_boundaries_enabled() -> bool { + std::env::var(TRUST_BOUNDARIES_ENV_VAR) + .map(|v| v.to_lowercase() == "true") + .unwrap_or(false) + } + + pub(crate) fn header_value(&self) -> Option { + let val = self.rx_header.borrow().clone(); + if let Some(ref v) = val { + if v == NO_OP_ENCODED_LOCATIONS { + return None; + } + } + val + } +} + +#[derive(serde::Deserialize)] +struct AllowedLocationsResponse { + #[allow(dead_code)] + locations: Vec, + #[serde(rename = "encodedLocations")] + encoded_locations: String, +} + +async fn fetch_trust_boundary( + token_provider: &T, + url: &str, +) -> Result, CredentialsError> +where + T: CachedTokenProvider, +{ + let token = token_provider.token(Extensions::new()).await?; + let headers = build_cacheable_headers(&token, &None, &None)?; + let headers = match headers { + CacheableResource::New { data, .. } => data, + CacheableResource::NotModified => { + unreachable!("requested trust boundary without a caching etag") + } + }; + + let client = Client::new(); + + // TODO: retries ? + let resp = client + .get(url) + .headers(headers) + .send() + .await + .map_err(|e| CredentialsError::from_msg(true, e.to_string()))?; + + // TODO: add error handling - default fallback ? + if !resp.status().is_success() { + return Err(CredentialsError::from_msg( + true, + format!("Failed to fetch trust boundary: {}", resp.status()), + )); + } + + let response: AllowedLocationsResponse = resp + .json() + .await + .map_err(|e| CredentialsError::from_msg(true, e.to_string()))?; + + if !response.encoded_locations.is_empty() { + return Ok(Some(response.encoded_locations)); + } + + Ok(None) +} + +async fn refresh_task_mds( + token_provider: T, + mds_client: MDSClient, + tx_header: watch::Sender>, +) where + T: CachedTokenProvider, +{ + let mut url: Option = None; + + loop { + if url.is_none() { + let res = mds_client.email().await; + match res { + Ok(email) => { + url = Some(service_account_lookup_url(&email)); + } + Err(_e) => { + sleep(ERROR_RETRY_INTERVAL).await; + continue; + } + } + } + + if let Some(ref url) = url { + fetch_and_update(&token_provider, url, &tx_header).await; + } + } +} + +async fn refresh_task(token_provider: T, url: String, tx_header: watch::Sender>) +where + T: CachedTokenProvider, +{ + loop { + fetch_and_update(&token_provider, &url, &tx_header).await; + } +} + +async fn fetch_and_update( + token_provider: &T, + url: &str, + tx_header: &watch::Sender>, +) where + T: CachedTokenProvider, +{ + match fetch_trust_boundary(token_provider, url).await { + Ok(val) => { + let _ = tx_header.send(val); + sleep(REFRESH_INTERVAL).await; + } + Err(_e) => { + // TODO: better error handling - default fallback ? + sleep(ERROR_RETRY_INTERVAL).await; + } + } +} + +pub(crate) fn service_account_lookup_url(email: &str) -> String { + format!( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations", + email + ) +} + +pub(crate) fn external_account_lookup_url(audience: &str) -> Option { + let path = audience + .trim_start_matches("//iam.googleapis.com/") + .trim_start_matches("https://iam.googleapis.com/") + .trim_start_matches('/'); + + let parts: Vec<&str> = path.split('/').collect(); + + // Workload: projects/{project}/locations/global/workloadIdentityPools/{pool}/providers/{provider} (6 parts) + if parts.len() >= 6 + && parts[0] == "projects" + && parts[2] == "locations" + && parts[4] == "workloadIdentityPools" + { + let project = parts[1]; + let pool = parts[5]; + return Some(format!( + "https://iamcredentials.googleapis.com/v1/projects/{}/locations/global/workloadIdentityPools/{}/allowedLocations", + project, pool + )); + } + + // Workforce: locations/global/workforcePools/{pool}/providers/{provider} (4 parts) + if parts.len() >= 4 && parts[0] == "locations" && parts[2] == "workforcePools" { + let pool = parts[3]; + return Some(format!( + "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/{}/allowedLocations", + pool + )); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_service_account_url() { + assert_eq!( + service_account_lookup_url("sa@project.iam.gserviceaccount.com"), + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa@project.iam.gserviceaccount.com/allowedLocations" + ); + } + + #[test] + fn test_external_account_url_workload() { + let aud = "//iam.googleapis.com/projects/12345/locations/global/workloadIdentityPools/my-pool/providers/my-provider"; + assert_eq!( + external_account_lookup_url(aud).unwrap(), + "https://iamcredentials.googleapis.com/v1/projects/12345/locations/global/workloadIdentityPools/my-pool/allowedLocations" + ); + } + + #[test] + fn test_external_account_url_workforce() { + let aud = + "//iam.googleapis.com/locations/global/workforcePools/my-pool/providers/my-provider"; + assert_eq!( + external_account_lookup_url(aud).unwrap(), + "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/my-pool/allowedLocations" + ); + } + + #[test] + fn test_external_account_url_invalid() { + assert!(external_account_lookup_url("invalid").is_none()); + assert!( + external_account_lookup_url("//iam.googleapis.com/projects/123/locations/global/wrong") + .is_none() + ); + } +}