Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 147 additions & 5 deletions src/auth/src/credentials/idtoken/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ use crate::Result;
use crate::credentials::CacheableResource;
use crate::errors::CredentialsError;
use crate::mds::client::Client as MDSClient;
use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::{
Expand All @@ -76,6 +77,9 @@ use crate::{
credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str},
};
use async_trait::async_trait;
use gax::backoff_policy::BackoffPolicyArg;
use gax::retry_policy::RetryPolicyArg;
use gax::retry_throttler::RetryThrottlerArg;
use http::Extensions;
use std::sync::Arc;

Expand Down Expand Up @@ -131,6 +135,7 @@ pub struct Builder {
pub(crate) format: Option<Format>,
licenses: Option<String>,
target_audience: String,
retry_builder: RetryTokenProviderBuilder,
}

impl Builder {
Expand All @@ -145,6 +150,7 @@ impl Builder {
endpoint: None,
licenses: None,
target_audience: target_audience.into(),
retry_builder: RetryTokenProviderBuilder::default(),
}
}

Expand Down Expand Up @@ -218,15 +224,83 @@ impl Builder {
self
}

fn build_token_provider(self) -> MDSTokenProvider {
let client = MDSClient::new(self.endpoint);
/// Configure the retry policy for fetching tokens.
///
/// The retry policy controls how to handle retries, and sets limits on
/// the number of attempts or the total time spent retrying.
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
/// .build();
/// ```
pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_policy(v.into());
self
}

MDSTokenProvider {
/// Configure the retry backoff policy.
///
/// The backoff policy controls how long to wait in between retry attempts.
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::exponential_backoff::ExponentialBackoff;
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_backoff_policy(ExponentialBackoff::default())
/// .build();
/// ```
pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
self
}

/// Configure the retry throttler.
///
/// Advanced applications may want to configure a retry throttler to
/// [Address Cascading Failures] and when [Handling Overload] conditions.
/// The authentication library throttles its retry loop, using a policy to
/// control the throttling algorithm. Use this method to fine tune or
/// customize the default retry throttler.
///
/// [Handling Overload]: https://sre.google/sre-book/handling-overload/
/// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::retry_throttler::AdaptiveThrottler;
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_retry_throttler(AdaptiveThrottler::default())
/// .build();
/// ```
pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
self
}

fn build_token_provider(self) -> TokenProviderWithRetry<MDSTokenProvider> {
let client = MDSClient::new(self.endpoint);
let tp = MDSTokenProvider {
format: self.format,
licenses: self.licenses,
client,
target_audience: self.target_audience,
}
};
self.retry_builder.build(tp)
}

/// Returns an [`IDTokenCredentials`] instance with the configured
Expand Down Expand Up @@ -266,8 +340,12 @@ impl TokenProvider for MDSTokenProvider {
mod tests {
use super::*;
use crate::credentials::idtoken::tests::generate_test_id_token;
use crate::credentials::tests::find_source_error;
use crate::credentials::tests::{
find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy,
get_mock_retry_throttler,
};
use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI};
use httptest::cycle;
use httptest::matchers::{all_of, contains, request, url_decoded};
use httptest::responders::status_code;
use httptest::{Expectation, Server};
Expand All @@ -278,6 +356,70 @@ mod tests {

type TestResult = anyhow::Result<()>;

#[tokio::test]
#[parallel]
async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
let server = Server::run();
let audience = "test-audience";
server.expect(
Expectation::matching(all_of![
request::path(format!("{MDS_DEFAULT_URI}/identity")),
request::query(url_decoded(contains(("audience", audience)))),
])
.times(1)
.respond_with(status_code(401)),
);

let creds = Builder::new(audience)
.with_endpoint(format!("http://{}", server.addr()))
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let err = creds.id_token().await.unwrap_err();
let source = find_source_error::<reqwest::Error>(&err);
assert!(
matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
"{err:?}"
);

Ok(())
}

#[tokio::test]
#[parallel]
async fn test_mds_retries_for_success() -> TestResult {
let server = Server::run();
let audience = "test-audience";
let token_string = generate_test_id_token(audience);

server.expect(
Expectation::matching(all_of![
request::path(format!("{MDS_DEFAULT_URI}/identity")),
request::query(url_decoded(contains(("audience", audience)))),
])
.times(3)
.respond_with(cycle![
status_code(503).body("try-again"),
status_code(503).body("try-again"),
status_code(200).body(token_string.clone()),
]),
);

let creds = Builder::new(audience)
.with_endpoint(format!("http://{}", server.addr()))
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let id_token = creds.id_token().await?;
assert_eq!(id_token, token_string);

Ok(())
}

#[tokio::test]
#[test_case(Format::Standard)]
#[test_case(Format::Full)]
Expand Down
Loading