Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 152 additions & 20 deletions src/auth/src/credentials/idtoken/user_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
use crate::build_errors::Error as BuilderError;
use crate::credentials::CacheableResource;
use crate::credentials::user_account::UserTokenProvider;
use crate::retry::Builder as RetryTokenProviderBuilder;
use crate::token::CachedTokenProvider;
use crate::token_cache::TokenCache;
use crate::{
Expand All @@ -69,7 +70,10 @@ use crate::{
},
};
use async_trait::async_trait;
use gax::backoff_policy::BackoffPolicyArg;
use gax::error::CredentialsError;
use gax::retry_policy::RetryPolicyArg;
use gax::retry_throttler::RetryThrottlerArg;
use http::Extensions;
use serde_json::Value;
use std::sync::Arc;
Expand Down Expand Up @@ -117,6 +121,7 @@ where
pub struct Builder {
authorized_user: Value,
token_uri: Option<String>,
retry_builder: RetryTokenProviderBuilder,
}

impl Builder {
Expand All @@ -131,6 +136,7 @@ impl Builder {
Self {
authorized_user,
token_uri: None,
retry_builder: RetryTokenProviderBuilder::default(),
}
}

Expand All @@ -155,12 +161,87 @@ impl Builder {
self
}

fn build_token_provider(self) -> BuildResult<UserTokenProvider> {
let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
.map_err(BuilderError::parsing)?;
/// Configure the retry policy for fetching tokens.
///
/// The retry policy controls how to handle retries, and sets limits on
/// the number of attempts or the total time spent retrying.
///
/// # Example
///
/// ```
/// # use google_cloud_auth::credentials::idtoken;
/// # use serde_json::json;
/// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
///
/// let authorized_user = json!({ /* add details here */ });
///
/// let credentials = idtoken::user_account::Builder::new(authorized_user)
/// .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
/// .build();
/// ```
pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_policy(v.into());
self
}

/// Configure the retry backoff policy.
///
/// The backoff policy controls how long to wait in between retry attempts.
///
/// # Example
///
/// ```
/// # use google_cloud_auth::credentials::idtoken;
/// # use serde_json::json;
/// use gax::exponential_backoff::ExponentialBackoff;
///
/// let authorized_user = json!({ /* add details here */ });
///
/// let credentials = idtoken::user_account::Builder::new(authorized_user)
/// .with_backoff_policy(ExponentialBackoff::default())
/// .build();
/// ```
pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
self
}

/// Configure the retry throttler.
///
/// Advanced applications may want to configure a retry throttler to
/// [Address Cascading Failures] and when [Handling Overload] conditions.
/// The authentication library throttles its retry loop, using a policy to
/// control the throttling algorithm. Use this method to fine tune or
/// customize the default retry throttler.
///
/// [Handling Overload]: https://sre.google/sre-book/handling-overload/
/// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
///
/// # Example
///
/// ```
/// # use google_cloud_auth::credentials::idtoken;
/// # use serde_json::json;
/// use gax::retry_throttler::AdaptiveThrottler;
///
/// let authorized_user = json!({ /* add details here */ });
///
/// let credentials = idtoken::user_account::Builder::new(authorized_user)
/// .with_retry_throttler(AdaptiveThrottler::default())
/// .build();
/// ```
pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
self
}

fn build_token_provider(&self) -> BuildResult<UserTokenProvider> {
let authorized_user =
serde_json::from_value::<AuthorizedUser>(self.authorized_user.clone())
.map_err(BuilderError::parsing)?;
Ok(UserTokenProvider::new_id_token_provider(
authorized_user,
self.token_uri,
self.token_uri.clone(),
))
}

Expand All @@ -178,8 +259,11 @@ impl Builder {
///
/// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
pub fn build(self) -> BuildResult<IDTokenCredentials> {
let provider = self.build_token_provider()?;
let provider = self.retry_builder.build(provider);

let creds = UserAccountCredentials {
token_provider: TokenCache::new(self.build_token_provider()?),
token_provider: TokenCache::new(provider),
};
Ok(IDTokenCredentials {
inner: Arc::new(creds),
Expand All @@ -189,13 +273,19 @@ impl Builder {

#[cfg(test)]
mod tests {
use std::error::Error;

use super::*;
use crate::credentials::idtoken::tests::generate_test_id_token;
use crate::credentials::tests::find_source_error;
use crate::credentials::tests::{
find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy,
get_mock_retry_throttler,
};
use crate::credentials::user_account::{
Oauth2RefreshRequest, Oauth2RefreshResponse, RefreshGrantType,
};
use http::StatusCode;
use httptest::cycle;
use httptest::matchers::{all_of, json_decoded, request};
use httptest::responders::{json_encoded, status_code};
use httptest::{Expectation, Server};
Expand Down Expand Up @@ -274,8 +364,10 @@ mod tests {
let creds = Builder::new(authorized_user).build()?;
let err = creds.id_token().await.unwrap_err();
assert!(!err.is_transient());
let source = err.source().unwrap();
assert!(
err.to_string()
source
.to_string()
.contains("can obtain an id token only when authenticated through gcloud")
);
Ok(())
Expand All @@ -295,40 +387,80 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn id_token_retryable_error() -> TestResult {
async fn id_token_nonretryable_error() -> TestResult {
let server = Server::run();
server
.expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
.expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

let authorized_user = authorized_user_json(server.url("/token").to_string());
let creds = Builder::new(authorized_user).build()?;
let err = creds.id_token().await.unwrap_err();
assert!(err.is_transient());
assert!(!err.is_transient());

let source = find_source_error::<reqwest::Error>(&err);
assert!(
matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
"{err:?}"
);
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn id_token_nonretryable_error() -> TestResult {
async fn test_user_account_id_token_retries_on_transient_failures() -> TestResult {
let server = Server::run();
server
.expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
server.expect(
Expectation::matching(request::path("/token"))
.times(3)
.respond_with(status_code(503)),
);

let authorized_user = authorized_user_json(server.url("/token").to_string());
let creds = Builder::new(authorized_user).build()?;
let err = creds.id_token().await.unwrap_err();
let credentials = Builder::new(authorized_user)
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let err = credentials.id_token().await.unwrap_err();
assert!(!err.is_transient());

let source = find_source_error::<reqwest::Error>(&err);
assert!(
matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
"{err:?}"
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_user_account_id_token_retries_for_success() -> TestResult {
let server = Server::run();
let response = Oauth2RefreshResponse {
access_token: "test-access-token".to_string(),
id_token: Some("test-id-token".to_string()),
expires_in: Some(3600),
refresh_token: Some("test-refresh-token".to_string()),
scope: None,
token_type: "Bearer".to_string(),
};

server.expect(
Expectation::matching(request::path("/token"))
.times(3)
.respond_with(cycle![
status_code(503).body("try-again"),
status_code(503).body("try-again"),
status_code(200)
.append_header("Content-Type", "application/json")
.body(serde_json::to_string(&response).unwrap()),
]),
);

let authorized_user = authorized_user_json(server.url("/token").to_string());
let credentials = Builder::new(authorized_user)
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let id_token = credentials.id_token().await.unwrap();
assert_eq!(id_token, "test-id-token");

Ok(())
}

Expand Down
Loading