Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thre
tokio-test.workspace = true
url.workspace = true
mutants.workspace = true
static_assertions = { workspace = true }

[features]
default = ["default-idtoken-backend", "default-rustls-provider"]
Expand Down
108 changes: 84 additions & 24 deletions src/auth/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,85 @@ use gax::retry_loop_internal::retry_loop;
use gax::retry_policy::{AlwaysRetry, RetryPolicy, RetryPolicyArg, RetryPolicyExt};
use gax::retry_throttler::{AdaptiveThrottler, RetryThrottlerArg, SharedRetryThrottler};
use std::error::Error;
use std::panic::RefUnwindSafe;
use std::sync::{Arc, Mutex};

/// A wrapper to assert `RefUnwindSafe` on the inner type.
///
/// This is necessary because adding [TokenProviderWithRetry] to existing, released auth features
/// caused a breaking change. The containing structs would lose their automatically derived
/// `RefUnwindSafe` implementation because the dynamic trait objects (like `dyn RetryPolicy`)
/// are not `RefUnwindSafe` by default.
///
/// This wrapper solves that by manually implementing `RefUnwindSafe`, allowing us to add
/// retry functionality without triggering a semver-check failure. This is safe because
/// we control the implementation of the inner types and can ensure that they are
/// `RefUnwindSafe`.
#[derive(Debug)]
struct UnwindSafeAdapter<T>(T);

impl<T> RefUnwindSafe for UnwindSafeAdapter<T> {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are going to do this, I propose we simply implement the trait for TokenProviderWithRetry (or the builder for it).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an option too. I did this because only retry_policy and backoff_policy are not UnwindSafe, but you're right, in the end we are adding a impl RefUnwindSafe thing anyway.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, weird. I had to add both RefUnwindSafe and UnwindSafe explicitly to make it work:

impl RefUnwindSafe for Builder {}
impl UnwindSafe for Builder {}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, weird. I don't see the changes here, I assume you are experimenting locally or something. Let me know when this is ready for review again.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I'm testing in another branch, so I can run semver-check to make sure it looks good. The latest changes are on this one: https://github.com/googleapis/google-cloud-rust/pull/4256/changes#diff-43612f24f6413f8e5304d6af3e912553d7913ceea55c203c549db570b52d0d30

If I don't add both impl, the check fails.


impl<T> From<T> for UnwindSafeAdapter<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}

impl RetryPolicy for UnwindSafeAdapter<Arc<dyn RetryPolicy>> {
fn on_error(
&self,
state: &gax::retry_state::RetryState,
error: gax::error::Error,
) -> gax::retry_result::RetryResult {
self.0.on_error(state, error)
}
fn on_throttle(
&self,
state: &gax::retry_state::RetryState,
error: gax::error::Error,
) -> gax::throttle_result::ThrottleResult {
self.0.on_throttle(state, error)
}
fn remaining_time(&self, state: &gax::retry_state::RetryState) -> Option<std::time::Duration> {
self.0.remaining_time(state)
}
}

impl BackoffPolicy for UnwindSafeAdapter<Arc<dyn BackoffPolicy>> {
fn on_failure(&self, state: &gax::retry_state::RetryState) -> std::time::Duration {
self.0.on_failure(state)
}
}

#[derive(Debug)]
pub(crate) struct TokenProviderWithRetry<T: TokenProvider> {
pub(crate) inner: Arc<T>,
retry_policy: Arc<dyn RetryPolicy>,
backoff_policy: Arc<dyn BackoffPolicy>,
retry_policy: Arc<dyn RetryPolicy + RefUnwindSafe>,
backoff_policy: Arc<dyn BackoffPolicy + RefUnwindSafe>,
retry_throttler: SharedRetryThrottler,
}

#[derive(Debug, Default)]
pub(crate) struct Builder {
retry_policy: Option<RetryPolicyArg>,
backoff_policy: Option<BackoffPolicyArg>,
retry_policy: Option<Arc<dyn RetryPolicy + RefUnwindSafe>>,
backoff_policy: Option<Arc<dyn BackoffPolicy + RefUnwindSafe>>,
retry_throttler: Option<RetryThrottlerArg>,
}

impl Builder {
pub(crate) fn with_retry_policy(mut self, retry_policy: RetryPolicyArg) -> Self {
self.retry_policy = Some(retry_policy);
pub(crate) fn with_retry_policy<P: Into<RetryPolicyArg>>(mut self, retry_policy: P) -> Self {
let inner: Arc<dyn RetryPolicy> = retry_policy.into().into();
self.retry_policy = Some(Arc::new(UnwindSafeAdapter::from(inner)));
self
}

pub(crate) fn with_backoff_policy(mut self, backoff_policy: BackoffPolicyArg) -> Self {
self.backoff_policy = Some(backoff_policy);
pub(crate) fn with_backoff_policy<P: Into<BackoffPolicyArg>>(
mut self,
backoff_policy: P,
) -> Self {
let inner: Arc<dyn BackoffPolicy> = backoff_policy.into().into();
self.backoff_policy = Some(Arc::new(UnwindSafeAdapter::from(inner)));
self
}

Expand All @@ -57,19 +111,19 @@ impl Builder {
}

pub(crate) fn build<T: TokenProvider>(self, token_provider: T) -> TokenProviderWithRetry<T> {
let backoff_policy: Arc<dyn BackoffPolicy> = match self.backoff_policy {
Some(p) => p.into(),
None => Arc::new(ExponentialBackoff::default()),
};
let backoff_policy = self.backoff_policy.unwrap_or_else(|| {
let p: Arc<dyn BackoffPolicy> = Arc::new(ExponentialBackoff::default());
Arc::new(UnwindSafeAdapter::from(p))
});
let retry_throttler: SharedRetryThrottler = match self.retry_throttler {
Some(p) => p.into(),
None => Arc::new(Mutex::new(AdaptiveThrottler::default())),
};

let retry_policy = self
.retry_policy
.unwrap_or_else(|| AlwaysRetry.with_attempt_limit(1).into())
.into();
let retry_policy = self.retry_policy.unwrap_or_else(|| {
let p: Arc<dyn RetryPolicy> = Arc::new(AlwaysRetry.with_attempt_limit(1));
Arc::new(UnwindSafeAdapter::from(p))
});

TokenProviderWithRetry {
inner: Arc::new(token_provider),
Expand Down Expand Up @@ -142,6 +196,7 @@ mod tests {
use gax::retry_state::RetryState;
use gax::retry_throttler::RetryThrottler;
use mockall::{Sequence, mock};
use static_assertions::assert_impl_all;
use std::error::Error;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
Expand Down Expand Up @@ -222,7 +277,7 @@ mod tests {
.return_once(|| Ok(token));

let provider = Builder::default()
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 }.into())
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 })
.build(mock_provider);

let token = provider.token().await.unwrap();
Expand Down Expand Up @@ -253,7 +308,7 @@ mod tests {
});

let provider = Builder::default()
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 }.into())
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 })
.build(mock_provider);

let token = provider.token().await.unwrap();
Expand All @@ -269,7 +324,7 @@ mod tests {
.returning(|| Err(CredentialsError::from_msg(true, "transient error")));

let provider = Builder::default()
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 }.into())
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 })
.build(mock_provider);

let error = provider.token().await.unwrap_err();
Expand All @@ -288,7 +343,7 @@ mod tests {
.returning(|| Err(CredentialsError::from_msg(false, "non transient error")));

let provider = Builder::default()
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 }.into())
.with_retry_policy(AuthRetryPolicy { max_attempts: 2 })
.build(mock_provider);

let error = provider.token().await.unwrap_err();
Expand Down Expand Up @@ -379,8 +434,8 @@ mod tests {
let backoff_policy = TestBackoffPolicy::default();
let retry_throttler = AdaptiveThrottler::new(4.0).unwrap();
builder = builder
.with_retry_policy(retry_policy.into())
.with_backoff_policy(backoff_policy.into())
.with_retry_policy(retry_policy)
.with_backoff_policy(backoff_policy)
.with_retry_throttler(retry_throttler.into());
}

Expand Down Expand Up @@ -448,8 +503,8 @@ mod tests {

// 4. Build and run
let provider = Builder::default()
.with_retry_policy(retry_policy.into())
.with_backoff_policy(backoff_policy.into())
.with_retry_policy(retry_policy)
.with_backoff_policy(backoff_policy)
.with_retry_throttler(retry_throttler.into())
.build(mock_provider);

Expand Down Expand Up @@ -479,4 +534,9 @@ mod tests {
original_error_string
);
}

#[test]
fn test_unwind_safe() {
assert_impl_all!(Builder: std::panic::UnwindSafe, std::panic::RefUnwindSafe);
}
}
Loading