From aedbe7f8cfe957a0d5281989707a3931bc5f9a4f Mon Sep 17 00:00:00 2001 From: Alvaro Viebrantz Date: Mon, 12 Jan 2026 19:25:43 +0000 Subject: [PATCH] impl(auth): add retry to user account idtoken provider --- .../src/credentials/idtoken/user_account.rs | 172 ++++++++++++++++-- 1 file changed, 152 insertions(+), 20 deletions(-) diff --git a/src/auth/src/credentials/idtoken/user_account.rs b/src/auth/src/credentials/idtoken/user_account.rs index 96e1f01401..4a91855148 100644 --- a/src/auth/src/credentials/idtoken/user_account.rs +++ b/src/auth/src/credentials/idtoken/user_account.rs @@ -59,6 +59,7 @@ use crate::build_errors::Error as BuilderError; use crate::credentials::CacheableResource; use crate::credentials::user_account::UserTokenProvider; +use crate::retry::Builder as RetryTokenProviderBuilder; use crate::token::CachedTokenProvider; use crate::token_cache::TokenCache; use crate::{ @@ -69,7 +70,10 @@ use crate::{ }, }; use async_trait::async_trait; +use gax::backoff_policy::BackoffPolicyArg; use gax::error::CredentialsError; +use gax::retry_policy::RetryPolicyArg; +use gax::retry_throttler::RetryThrottlerArg; use http::Extensions; use serde_json::Value; use std::sync::Arc; @@ -117,6 +121,7 @@ where pub struct Builder { authorized_user: Value, token_uri: Option, + retry_builder: RetryTokenProviderBuilder, } impl Builder { @@ -131,6 +136,7 @@ impl Builder { Self { authorized_user, token_uri: None, + retry_builder: RetryTokenProviderBuilder::default(), } } @@ -155,12 +161,87 @@ impl Builder { self } - fn build_token_provider(self) -> BuildResult { - let authorized_user = serde_json::from_value::(self.authorized_user) - .map_err(BuilderError::parsing)?; + /// Configure the retry policy for fetching tokens. + /// + /// The retry policy controls how to handle retries, and sets limits on + /// the number of attempts or the total time spent retrying. + /// + /// # Example + /// + /// ``` + /// # use google_cloud_auth::credentials::idtoken; + /// # use serde_json::json; + /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt}; + /// + /// let authorized_user = json!({ /* add details here */ }); + /// + /// let credentials = idtoken::user_account::Builder::new(authorized_user) + /// .with_retry_policy(AlwaysRetry.with_attempt_limit(3)) + /// .build(); + /// ``` + pub fn with_retry_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_policy(v.into()); + self + } + + /// Configure the retry backoff policy. + /// + /// The backoff policy controls how long to wait in between retry attempts. + /// + /// # Example + /// + /// ``` + /// # use google_cloud_auth::credentials::idtoken; + /// # use serde_json::json; + /// use gax::exponential_backoff::ExponentialBackoff; + /// + /// let authorized_user = json!({ /* add details here */ }); + /// + /// let credentials = idtoken::user_account::Builder::new(authorized_user) + /// .with_backoff_policy(ExponentialBackoff::default()) + /// .build(); + /// ``` + pub fn with_backoff_policy>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_backoff_policy(v.into()); + self + } + + /// Configure the retry throttler. + /// + /// Advanced applications may want to configure a retry throttler to + /// [Address Cascading Failures] and when [Handling Overload] conditions. + /// The authentication library throttles its retry loop, using a policy to + /// control the throttling algorithm. Use this method to fine tune or + /// customize the default retry throttler. + /// + /// [Handling Overload]: https://sre.google/sre-book/handling-overload/ + /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/ + /// + /// # Example + /// + /// ``` + /// # use google_cloud_auth::credentials::idtoken; + /// # use serde_json::json; + /// use gax::retry_throttler::AdaptiveThrottler; + /// + /// let authorized_user = json!({ /* add details here */ }); + /// + /// let credentials = idtoken::user_account::Builder::new(authorized_user) + /// .with_retry_throttler(AdaptiveThrottler::default()) + /// .build(); + /// ``` + pub fn with_retry_throttler>(mut self, v: V) -> Self { + self.retry_builder = self.retry_builder.with_retry_throttler(v.into()); + self + } + + fn build_token_provider(&self) -> BuildResult { + let authorized_user = + serde_json::from_value::(self.authorized_user.clone()) + .map_err(BuilderError::parsing)?; Ok(UserTokenProvider::new_id_token_provider( authorized_user, - self.token_uri, + self.token_uri.clone(), )) } @@ -178,8 +259,11 @@ impl Builder { /// /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials pub fn build(self) -> BuildResult { + let provider = self.build_token_provider()?; + let provider = self.retry_builder.build(provider); + let creds = UserAccountCredentials { - token_provider: TokenCache::new(self.build_token_provider()?), + token_provider: TokenCache::new(provider), }; Ok(IDTokenCredentials { inner: Arc::new(creds), @@ -189,12 +273,18 @@ impl Builder { #[cfg(test)] mod tests { + use std::error::Error; + use super::*; - use crate::credentials::tests::find_source_error; + use crate::credentials::tests::{ + find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy, + get_mock_retry_throttler, + }; use crate::credentials::user_account::{ Oauth2RefreshRequest, Oauth2RefreshResponse, RefreshGrantType, }; use http::StatusCode; + use httptest::cycle; use httptest::matchers::{all_of, json_decoded, request}; use httptest::responders::{json_encoded, status_code}; use httptest::{Expectation, Server}; @@ -271,8 +361,10 @@ mod tests { let creds = Builder::new(authorized_user).build()?; let err = creds.id_token().await.unwrap_err(); assert!(!err.is_transient()); + let source = err.source().unwrap(); assert!( - err.to_string() + source + .to_string() .contains("can obtain an id token only when authenticated through gcloud") ); Ok(()) @@ -292,40 +384,80 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn id_token_retryable_error() -> TestResult { + async fn id_token_nonretryable_error() -> TestResult { let server = Server::run(); server - .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503))); + .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401))); let authorized_user = authorized_user_json(server.url("/token").to_string()); let creds = Builder::new(authorized_user).build()?; let err = creds.id_token().await.unwrap_err(); - assert!(err.is_transient()); + assert!(!err.is_transient()); let source = find_source_error::(&err); assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), + matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), "{err:?}" ); Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn id_token_nonretryable_error() -> TestResult { + async fn test_user_account_id_token_retries_on_transient_failures() -> TestResult { let server = Server::run(); - server - .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401))); + server.expect( + Expectation::matching(request::path("/token")) + .times(3) + .respond_with(status_code(503)), + ); let authorized_user = authorized_user_json(server.url("/token").to_string()); - let creds = Builder::new(authorized_user).build()?; - let err = creds.id_token().await.unwrap_err(); + let credentials = Builder::new(authorized_user) + .with_retry_policy(get_mock_auth_retry_policy(3)) + .with_backoff_policy(get_mock_backoff_policy()) + .with_retry_throttler(get_mock_retry_throttler()) + .build()?; + + let err = credentials.id_token().await.unwrap_err(); assert!(!err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), - "{err:?}" + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_user_account_id_token_retries_for_success() -> TestResult { + let server = Server::run(); + let response = Oauth2RefreshResponse { + access_token: "test-access-token".to_string(), + id_token: Some("test-id-token".to_string()), + expires_in: Some(3600), + refresh_token: Some("test-refresh-token".to_string()), + scope: None, + token_type: "Bearer".to_string(), + }; + + server.expect( + Expectation::matching(request::path("/token")) + .times(3) + .respond_with(cycle![ + status_code(503).body("try-again"), + status_code(503).body("try-again"), + status_code(200) + .append_header("Content-Type", "application/json") + .body(serde_json::to_string(&response).unwrap()), + ]), ); + + let authorized_user = authorized_user_json(server.url("/token").to_string()); + let credentials = Builder::new(authorized_user) + .with_retry_policy(get_mock_auth_retry_policy(3)) + .with_backoff_policy(get_mock_backoff_policy()) + .with_retry_throttler(get_mock_retry_throttler()) + .build()?; + + let id_token = credentials.id_token().await.unwrap(); + assert_eq!(id_token, "test-id-token"); + Ok(()) } }