diff --git a/src/auth/src/credentials/idtoken/mds.rs b/src/auth/src/credentials/idtoken/mds.rs index d89c29ca6a..94d4c48b26 100644 --- a/src/auth/src/credentials/idtoken/mds.rs +++ b/src/auth/src/credentials/idtoken/mds.rs @@ -68,6 +68,7 @@ use crate::Result; use crate::credentials::CacheableResource; use crate::errors::CredentialsError; use crate::mds::client::Client as MDSClient; +use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; use crate::{ @@ -76,6 +77,9 @@ use crate::{ credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str}, }; use async_trait::async_trait; +use gax::backoff_policy::BackoffPolicyArg; +use gax::retry_policy::RetryPolicyArg; +use gax::retry_throttler::RetryThrottlerArg; use http::Extensions; use std::sync::Arc; @@ -131,6 +135,7 @@ pub struct Builder { pub(crate) format: Option, licenses: Option, target_audience: String, + retry_builder: RetryTokenProviderBuilder, } impl Builder { @@ -145,6 +150,7 @@ impl Builder { endpoint: None, licenses: None, target_audience: target_audience.into(), + retry_builder: RetryTokenProviderBuilder::default(), } } @@ -218,15 +224,83 @@ impl Builder { self } - fn build_token_provider(self) -> MDSTokenProvider { - let client = MDSClient::new(self.endpoint); + /// Configure the retry policy for fetching tokens. + /// + /// The retry policy controls how to handle retries, and sets limits on + /// the number of attempts or the total time spent retrying. + /// + /// # Example + /// + /// ```no_run + /// # use google_cloud_auth::credentials::idtoken; + /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt}; + /// + /// let audience = "https://my-service.a.run.app"; + /// let credentials = idtoken::mds::Builder::new(audience) + /// .with_retry_policy(AlwaysRetry.with_attempt_limit(3)) + /// .build(); + /// ``` + pub fn with_retry_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_policy(v.into()); + self + } - MDSTokenProvider { + /// Configure the retry backoff policy. + /// + /// The backoff policy controls how long to wait in between retry attempts. + /// + /// # Example + /// + /// ```no_run + /// # use google_cloud_auth::credentials::idtoken; + /// use gax::exponential_backoff::ExponentialBackoff; + /// + /// let audience = "https://my-service.a.run.app"; + /// let credentials = idtoken::mds::Builder::new(audience) + /// .with_backoff_policy(ExponentialBackoff::default()) + /// .build(); + /// ``` + pub fn with_backoff_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_backoff_policy(v.into()); + self + } + + /// Configure the retry throttler. + /// + /// Advanced applications may want to configure a retry throttler to + /// [Address Cascading Failures] and when [Handling Overload] conditions. + /// The authentication library throttles its retry loop, using a policy to + /// control the throttling algorithm. Use this method to fine tune or + /// customize the default retry throttler. + /// + /// [Handling Overload]: https://sre.google/sre-book/handling-overload/ + /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/ + /// + /// # Example + /// + /// ```no_run + /// # use google_cloud_auth::credentials::idtoken; + /// use gax::retry_throttler::AdaptiveThrottler; + /// + /// let audience = "https://my-service.a.run.app"; + /// let credentials = idtoken::mds::Builder::new(audience) + /// .with_retry_throttler(AdaptiveThrottler::default()) + /// .build(); + /// ``` + pub fn with_retry_throttler>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_throttler(v.into()); + self + } + + fn build_token_provider(self) -> TokenProviderWithRetry { + let client = MDSClient::new(self.endpoint); + let tp = MDSTokenProvider { format: self.format, licenses: self.licenses, client, target_audience: self.target_audience, - } + }; + self.retry_builder.build(tp) } /// Returns an [`IDTokenCredentials`] instance with the configured @@ -266,8 +340,12 @@ impl TokenProvider for MDSTokenProvider { mod tests { use super::*; use crate::credentials::idtoken::tests::generate_test_id_token; - use crate::credentials::tests::find_source_error; + use crate::credentials::tests::{ + find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy, + get_mock_retry_throttler, + }; use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI}; + use httptest::cycle; use httptest::matchers::{all_of, contains, request, url_decoded}; use httptest::responders::status_code; use httptest::{Expectation, Server}; @@ -278,6 +356,70 @@ mod tests { type TestResult = anyhow::Result<()>; + #[tokio::test] + #[parallel] + async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult { + let server = Server::run(); + let audience = "test-audience"; + server.expect( + Expectation::matching(all_of![ + request::path(format!("{MDS_DEFAULT_URI}/identity")), + request::query(url_decoded(contains(("audience", audience)))), + ]) + .times(1) + .respond_with(status_code(401)), + ); + + let creds = Builder::new(audience) + .with_endpoint(format!("http://{}", server.addr())) + .with_retry_policy(get_mock_auth_retry_policy(3)) + .with_backoff_policy(get_mock_backoff_policy()) + .with_retry_throttler(get_mock_retry_throttler()) + .build()?; + + let err = creds.id_token().await.unwrap_err(); + let source = find_source_error::(&err); + assert!( + matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), + "{err:?}" + ); + + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn test_mds_retries_for_success() -> TestResult { + let server = Server::run(); + let audience = "test-audience"; + let token_string = generate_test_id_token(audience); + + server.expect( + Expectation::matching(all_of![ + request::path(format!("{MDS_DEFAULT_URI}/identity")), + request::query(url_decoded(contains(("audience", audience)))), + ]) + .times(3) + .respond_with(cycle![ + status_code(503).body("try-again"), + status_code(503).body("try-again"), + status_code(200).body(token_string.clone()), + ]), + ); + + let creds = Builder::new(audience) + .with_endpoint(format!("http://{}", server.addr())) + .with_retry_policy(get_mock_auth_retry_policy(3)) + .with_backoff_policy(get_mock_backoff_policy()) + .with_retry_throttler(get_mock_retry_throttler()) + .build()?; + + let id_token = creds.id_token().await?; + assert_eq!(id_token, token_string); + + Ok(()) + } + #[tokio::test] #[test_case(Format::Standard)] #[test_case(Format::Full)]