Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/auth/src/credentials/external_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ use crate::headers_util::build_cacheable_headers;
use crate::retry::Builder as RetryTokenProviderBuilder;
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::trust_boundary::TrustBoundary;
use crate::{BuildResult, Result};
use gax::backoff_policy::BackoffPolicyArg;
use gax::retry_policy::RetryPolicyArg;
Expand Down Expand Up @@ -359,16 +360,21 @@ impl ExternalAccountConfig {
where
T: dynamic::SubjectTokenProvider + 'static,
{
let trust_boundary_url =
crate::trust_boundary::external_account_lookup_url(&config.audience);
let token_provider = ExternalAccountTokenProvider {
subject_token_provider,
config,
};
let token_provider_with_retry = retry_builder.build(token_provider);
let cache = TokenCache::new(token_provider_with_retry);
let trust_boundary =
trust_boundary_url.map(|url| Arc::new(TrustBoundary::new(cache.clone(), url)));
AccessTokenCredentials {
inner: Arc::new(ExternalAccountCredentials {
token_provider: cache,
quota_project_id,
trust_boundary,
}),
}
}
Expand Down Expand Up @@ -457,6 +463,7 @@ where
{
token_provider: T,
quota_project_id: Option<String>,
trust_boundary: Option<Arc<TrustBoundary>>,
}

/// A builder for external account [Credentials] instances.
Expand Down Expand Up @@ -1279,7 +1286,12 @@ where
{
async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
let token = self.token_provider.token(extensions).await?;
build_cacheable_headers(&token, &self.quota_project_id)
let trust_boundary_header = self
.trust_boundary
.as_ref()
.and_then(|tb| tb.header_value());

build_cacheable_headers(&token, &self.quota_project_id, &trust_boundary_header)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/impersonated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ where
{
async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
let token = self.token_provider.token(extensions).await?;
build_cacheable_headers(&token, &self.quota_project_id)
build_cacheable_headers(&token, &self.quota_project_id, &None)
}
}

Expand Down
31 changes: 26 additions & 5 deletions src/auth/src/credentials/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ use crate::mds::client::Client as MDSClient;
use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::trust_boundary::TrustBoundary;
use crate::{BuildResult, Result};
use async_trait::async_trait;
use gax::backoff_policy::BackoffPolicyArg;
Expand All @@ -107,6 +108,7 @@ where
{
quota_project_id: Option<String>,
token_provider: T,
trust_boundary: Arc<TrustBoundary>,
}

/// Creates [Credentials] instances backed by the [Metadata Service].
Expand Down Expand Up @@ -282,9 +284,19 @@ impl Builder {
/// # });
/// ```
pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
let quota_project_id = self.quota_project_id.clone();
let mds_client = MDSClient::new(self.endpoint.clone());
let token_provider = TokenCache::new(self.build_token_provider());

let trust_boundary = Arc::new(TrustBoundary::new_for_mds(
token_provider.clone(),
mds_client.clone(),
));

let mdsc = MDSCredentials {
quota_project_id: self.quota_project_id.clone(),
token_provider: TokenCache::new(self.build_token_provider()),
quota_project_id,
token_provider,
trust_boundary,
};
Ok(AccessTokenCredentials {
inner: Arc::new(mdsc),
Expand Down Expand Up @@ -338,7 +350,12 @@ where
{
async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
let cached_token = self.token_provider.token(extensions).await?;
build_cacheable_headers(&cached_token, &self.quota_project_id)
let trust_boundary_header_value = self.trust_boundary.header_value();
build_cacheable_headers(
&cached_token,
&self.quota_project_id,
&trust_boundary_header_value,
)
}
}

Expand Down Expand Up @@ -567,9 +584,11 @@ mod tests {
let mut mock = MockTokenProvider::new();
mock.expect_token().times(1).return_once(|| Ok(token));

let cache = TokenCache::new(mock);
let mdsc = MDSCredentials {
quota_project_id: None,
token_provider: TokenCache::new(mock),
token_provider: cache.clone(),
trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())),
};

let mut extensions = Extensions::new();
Expand Down Expand Up @@ -627,9 +646,11 @@ mod tests {
.times(1)
.return_once(|| Err(errors::non_retryable_from_str("fail")));

let cache = TokenCache::new(mock);
let mdsc = MDSCredentials {
quota_project_id: None,
token_provider: TokenCache::new(mock),
token_provider: cache.clone(),
trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())),
};
assert!(mdsc.headers(Extensions::new()).await.is_err());
}
Expand Down
33 changes: 27 additions & 6 deletions src/auth/src/credentials/service_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ use crate::errors::{self};
use crate::headers_util::build_cacheable_headers;
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::trust_boundary::TrustBoundary;
use crate::{BuildResult, Result};
use async_trait::async_trait;
use http::{Extensions, HeaderMap};
Expand Down Expand Up @@ -328,10 +329,21 @@ impl Builder {
///
/// [service account keys]: https://cloud.google.com/iam/docs/keys-create-delete#creating
pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
let quota_project_id = self.quota_project_id.clone();
let token_provider = self.build_token_provider()?;
let client_email = token_provider.service_account_key.client_email.clone();

let token_provider = TokenCache::new(token_provider);
let trust_boundary_url = crate::trust_boundary::service_account_lookup_url(&client_email);
let trust_boundary = Arc::new(TrustBoundary::new(
token_provider.clone(),
trust_boundary_url,
));
Ok(AccessTokenCredentials {
inner: Arc::new(ServiceAccountCredentials {
quota_project_id: self.quota_project_id.clone(),
token_provider: TokenCache::new(self.build_token_provider()?),
quota_project_id,
token_provider,
trust_boundary,
}),
})
}
Expand Down Expand Up @@ -462,6 +474,7 @@ where
{
token_provider: T,
quota_project_id: Option<String>,
trust_boundary: Arc<TrustBoundary>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -573,7 +586,9 @@ where
{
async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
let token = self.token_provider.token(extensions).await?;
build_cacheable_headers(&token, &self.quota_project_id)
let trust_boundary_header = self.trust_boundary.header_value();

build_cacheable_headers(&token, &self.quota_project_id, &trust_boundary_header)
}
}

Expand Down Expand Up @@ -652,9 +667,11 @@ mod tests {
let mut mock = MockTokenProvider::new();
mock.expect_token().times(1).return_once(|| Ok(token));

let cache = TokenCache::new(mock);
let sac = ServiceAccountCredentials {
token_provider: TokenCache::new(mock),
token_provider: cache.clone(),
quota_project_id: None,
trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())),
};

let mut extensions = Extensions::new();
Expand Down Expand Up @@ -694,9 +711,11 @@ mod tests {
let mut mock = MockTokenProvider::new();
mock.expect_token().times(1).return_once(|| Ok(token));

let cache = TokenCache::new(mock);
let sac = ServiceAccountCredentials {
token_provider: TokenCache::new(mock),
token_provider: cache.clone(),
quota_project_id: Some(quota_project.to_string()),
trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())),
};

let headers = get_headers_from_cache(sac.headers(Extensions::new()).await.unwrap())?;
Expand All @@ -721,9 +740,11 @@ mod tests {
.times(1)
.return_once(|| Err(errors::non_retryable_from_str("fail")));

let cache = TokenCache::new(mock);
let sac = ServiceAccountCredentials {
token_provider: TokenCache::new(mock),
token_provider: cache.clone(),
quota_project_id: None,
trust_boundary: Arc::new(TrustBoundary::new(cache, "http://localhost".to_string())),
};
assert!(sac.headers(Extensions::new()).await.is_err());
}
Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/credentials/user_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ where
{
async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
let token = self.token_provider.token(extensions).await?;
build_cacheable_headers(&token, &self.quota_project_id)
build_cacheable_headers(&token, &self.quota_project_id, &None)
}
}

Expand Down
72 changes: 62 additions & 10 deletions src/auth/src/headers_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::Result;
use crate::credentials::{CacheableResource, QUOTA_PROJECT_KEY};
use crate::errors;
use crate::token::Token;
use crate::trust_boundary::TRUST_BOUNDARY_HEADER;

use http::HeaderMap;
use http::header::{AUTHORIZATION, HeaderName, HeaderValue};
Expand Down Expand Up @@ -55,11 +56,12 @@ const API_KEY_HEADER_KEY: &str = "x-goog-api-key";
pub(crate) fn build_cacheable_headers(
cached_token: &CacheableResource<Token>,
quota_project_id: &Option<String>,
trust_boundary_header: &Option<String>,
) -> Result<CacheableResource<HeaderMap>> {
match cached_token {
CacheableResource::NotModified => Ok(CacheableResource::NotModified),
CacheableResource::New { entity_tag, data } => {
let headers = build_bearer_headers(data, quota_project_id)?;
let headers = build_bearer_headers(data, quota_project_id, trust_boundary_header)?;
Ok(CacheableResource::New {
entity_tag: entity_tag.clone(),
data: headers,
Expand All @@ -72,11 +74,18 @@ pub(crate) fn build_cacheable_headers(
fn build_bearer_headers(
token: &crate::token::Token,
quota_project_id: &Option<String>,
trust_boundary_header: &Option<String>,
) -> Result<HeaderMap> {
build_headers(token, quota_project_id, AUTHORIZATION, |token| {
HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))
.map_err(errors::non_retryable)
})
build_headers(
token,
quota_project_id,
trust_boundary_header,
AUTHORIZATION,
|token| {
HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))
.map_err(errors::non_retryable)
},
)
}

pub(crate) fn build_cacheable_api_key_headers(
Expand All @@ -99,6 +108,7 @@ fn build_api_key_headers(token: &crate::token::Token) -> Result<HeaderMap> {
build_headers(
token,
&None,
&None,
HeaderName::from_static(API_KEY_HEADER_KEY),
|token| HeaderValue::from_str(&token.token).map_err(errors::non_retryable),
)
Expand All @@ -108,6 +118,7 @@ fn build_api_key_headers(token: &crate::token::Token) -> Result<HeaderMap> {
fn build_headers(
token: &crate::token::Token,
quota_project_id: &Option<String>,
trust_boundary_header: &Option<String>,
header_name: HeaderName,
build_header_value: impl FnOnce(&crate::token::Token) -> Result<HeaderValue>,
) -> Result<HeaderMap> {
Expand All @@ -124,6 +135,13 @@ fn build_headers(
);
}

if let Some(trust_boundary) = trust_boundary_header {
header_map.insert(
HeaderName::from_static(TRUST_BOUNDARY_HEADER),
HeaderValue::from_str(trust_boundary).map_err(errors::non_retryable)?,
);
}

Ok(header_map)
}

Expand Down Expand Up @@ -151,7 +169,7 @@ mod tests {
data: token,
};

let result = build_cacheable_headers(&cacheable_token, &None);
let result = build_cacheable_headers(&cacheable_token, &None, &None);

assert!(result.is_ok());
let cached_headers = result.unwrap();
Expand All @@ -173,7 +191,7 @@ mod tests {
fn build_cacheable_headers_basic_not_modified() {
let cacheable_token = CacheableResource::NotModified;

let result = build_cacheable_headers(&cacheable_token, &None);
let result = build_cacheable_headers(&cacheable_token, &None, &None);

assert!(result.is_ok());
let cached_headers = result.unwrap();
Expand All @@ -192,7 +210,7 @@ mod tests {
};

let quota_project_id = Some("test-project-123".to_string());
let result = build_cacheable_headers(&cacheable_token, &quota_project_id);
let result = build_cacheable_headers(&cacheable_token, &quota_project_id, &None);

assert!(result.is_ok());
let cached_headers = result.unwrap();
Expand All @@ -214,11 +232,45 @@ mod tests {
assert_eq!(quota_project, HeaderValue::from_static("test-project-123"));
}

#[test]
fn build_cacheable_headers_with_trust_boundary_success() {
let token = create_test_token("test_token", "Bearer");
let cacheable_token = CacheableResource::New {
entity_tag: EntityTag::default(),
data: token,
};

let trust_boundary = Some("test-trust-boundary".to_string());
let result = build_cacheable_headers(&cacheable_token, &None, &trust_boundary);

assert!(result.is_ok());
let cached_headers = result.unwrap();
let headers = match cached_headers {
CacheableResource::New { data, .. } => data,
CacheableResource::NotModified => unreachable!("expecting new headers"),
};
assert_eq!(headers.len(), 2, "{headers:?}");

let token = headers
.get(HeaderName::from_static("authorization"))
.unwrap();
assert_eq!(token, HeaderValue::from_static("Bearer test_token"));
assert!(token.is_sensitive());

let trust_boundary = headers
.get(HeaderName::from_static(TRUST_BOUNDARY_HEADER))
.unwrap();
assert_eq!(
trust_boundary,
HeaderValue::from_static("test-trust-boundary")
);
}

#[test]
fn build_bearer_headers_different_token_type() {
let token = create_test_token("special_token", "MAC");

let result = build_bearer_headers(&token, &None);
let result = build_bearer_headers(&token, &None, &None);

assert!(result.is_ok());
let headers = result.unwrap();
Expand All @@ -237,7 +289,7 @@ mod tests {
fn build_bearer_headers_invalid_token() {
let token = create_test_token("token with \n invalid chars", "Bearer");

let result = build_bearer_headers(&token, &None);
let result = build_bearer_headers(&token, &None, &None);

assert!(result.is_err());
let error = result.unwrap_err();
Expand Down
Loading
Loading