From 67d19e8a3962fc111390ddd4f5777c563be40ac0 Mon Sep 17 00:00:00 2001 From: "Mathias V. Nielsen" <1547127+math280h@users.noreply.github.com> Date: Wed, 25 Feb 2026 21:33:53 -0500 Subject: [PATCH] feat: Shared HTTP Client --- CONTRIBUTING.md | 19 +++ Cargo.lock | 15 ++ Cargo.toml | 1 + README.md | 1 - crates/http/Cargo.toml | 13 ++ crates/http/src/lib.rs | 231 ++++++++++++++++++++++++++ crates/osv/Cargo.toml | 1 + crates/osv/src/lib.rs | 40 ++--- crates/registry/cargo/Cargo.toml | 1 + crates/registry/cargo/src/registry.rs | 109 +++++------- crates/registry/npm/Cargo.toml | 1 + crates/registry/npm/src/registry.rs | 167 ++++++------------- crates/registry/pypi/Cargo.toml | 1 + crates/registry/pypi/src/registry.rs | 92 +++------- 14 files changed, 410 insertions(+), 282 deletions(-) create mode 100644 crates/http/Cargo.toml create mode 100644 crates/http/src/lib.rs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e53d9ee..2f1512d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,25 @@ cargo test - Prefer explicit failure over silent fallback when checks/registry calls fail. - Keep changes focused; avoid unrelated refactors in feature PRs. +## Shared HTTP Utilities + +When adding or updating registry/advisory HTTP calls, use `safe-pkgs-registry-http` (crate path: +`crates/http`) instead of +open-coding per-crate `reqwest` request/retry/error logic. + +Use: +- `build_http_client()` for a preconfigured client (timeouts + user-agent). +- `send_with_retry(...)` for retry/backoff and `429`/`Retry-After` handling. +- `map_status_error(...)` for consistent status -> `RegistryError::Transport`. +- `parse_json(...)` for consistent JSON parse -> `RegistryError::InvalidResponse`. + +Avoid: +- direct `.send().await` + custom retry loops in registry crates +- ad-hoc user-agent headers per request +- hand-written status/error-mapping strings duplicated across crates + +Default user-agent is `safe-pkgs/<version>`. Override only via `SAFE_PKGS_HTTP_USER_AGENT` when needed. 
+ ## Add a New Registry ### 1) Create a new crate diff --git a/Cargo.lock b/Cargo.lock index 7b54d44..36b8a22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1397,6 +1397,7 @@ dependencies = [ "reqwest", "safe-pkgs-core", "safe-pkgs-osv", + "safe-pkgs-registry-http", "serde", "tokio", "toml", @@ -1490,6 +1491,7 @@ dependencies = [ "reqwest", "safe-pkgs-core", "safe-pkgs-osv", + "safe-pkgs-registry-http", "semver", "serde", "serde_json", @@ -1503,6 +1505,7 @@ version = "0.2.0" dependencies = [ "reqwest", "safe-pkgs-core", + "safe-pkgs-registry-http", "serde", "tokio", "wiremock", @@ -1517,12 +1520,24 @@ dependencies = [ "reqwest", "safe-pkgs-core", "safe-pkgs-osv", + "safe-pkgs-registry-http", "serde", "tokio", "toml", "wiremock", ] +[[package]] +name = "safe-pkgs-registry-http" +version = "0.2.0" +dependencies = [ + "reqwest", + "safe-pkgs-core", + "serde", + "tokio", + "wiremock", +] + [[package]] name = "schemars" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index 3d2e4d4..1984b5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ ".", "crates/core", + "crates/http", "crates/osv", "crates/registry/*", "crates/checks/*", diff --git a/README.md b/README.md index 6ca38ed..b4c0d5a 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,6 @@ Prioritized planned work: ### Now -- [ ] Shared registry HTTP utilities (retry/backoff/rate-limit handling/user-agent/error mapping) - [ ] Transitive dependency path visibility in lockfile audits - [ ] Dependency confusion defenses for internal/private package names - [ ] Policy simulation mode (`what-if`) without enforcement diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml new file mode 100644 index 0000000..383efbc --- /dev/null +++ b/crates/http/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "safe-pkgs-registry-http" +version.workspace = true +edition.workspace = true + +[dependencies] +reqwest.workspace = true +serde.workspace = true +tokio.workspace = true +safe-pkgs-core = { path = 
"../core" } + +[dev-dependencies] +wiremock.workspace = true diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs new file mode 100644 index 0000000..d7c23a3 --- /dev/null +++ b/crates/http/src/lib.rs @@ -0,0 +1,231 @@ +use reqwest::{Client, RequestBuilder, Response, StatusCode, header::HeaderMap}; +use safe_pkgs_core::RegistryError; +use serde::de::DeserializeOwned; +use std::time::Duration; + +const DEFAULT_MAX_ATTEMPTS: u8 = 3; +const DEFAULT_INITIAL_BACKOFF_MILLIS: u64 = 250; +const DEFAULT_MAX_BACKOFF_SECS: u64 = 5; +const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 5; +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20; + +pub const DEFAULT_USER_AGENT: &str = concat!("safe-pkgs/", env!("CARGO_PKG_VERSION")); + +#[derive(Debug, Clone, Copy)] +pub struct RetryPolicy { + pub max_attempts: u8, + pub initial_backoff: Duration, + pub max_backoff: Duration, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + max_attempts: DEFAULT_MAX_ATTEMPTS, + initial_backoff: Duration::from_millis(DEFAULT_INITIAL_BACKOFF_MILLIS), + max_backoff: Duration::from_secs(DEFAULT_MAX_BACKOFF_SECS), + } + } +} + +pub fn build_http_client() -> Client { + let user_agent = std::env::var("SAFE_PKGS_HTTP_USER_AGENT") + .ok() + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| DEFAULT_USER_AGENT.to_string()); + + Client::builder() + .user_agent(user_agent) + .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS)) + .timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)) + .build() + .unwrap_or_else(|_| Client::new()) +} + +pub async fn send_with_retry( + mut build_request: F, + operation: &str, + policy: RetryPolicy, +) -> Result +where + F: FnMut() -> RequestBuilder, +{ + let max_attempts = policy.max_attempts.max(1); + let mut attempt = 1u8; + loop { + let response = build_request().send().await; + + match response { + Ok(response) => { + if attempt < max_attempts && should_retry_status(response.status()) { + let delay = compute_retry_delay( + 
attempt, + policy, + parse_retry_after_seconds(response.headers()).map(Duration::from_secs), + ); + tokio::time::sleep(delay).await; + attempt = attempt.saturating_add(1); + continue; + } + + return Ok(response); + } + Err(source) => { + if attempt < max_attempts && should_retry_transport_error(&source) { + let delay = compute_retry_delay(attempt, policy, None); + tokio::time::sleep(delay).await; + attempt = attempt.saturating_add(1); + continue; + } + + return Err(transport_error(operation, source)); + } + } + } +} + +pub fn map_status_error(operation: &str, status: StatusCode) -> RegistryError { + RegistryError::Transport { + message: format!("{operation} returned status {status}"), + } +} + +pub async fn parse_json(response: Response, operation: &str) -> Result +where + T: DeserializeOwned, +{ + response + .json() + .await + .map_err(|source| RegistryError::InvalidResponse { + message: format!("failed to parse {operation} JSON: {source}"), + }) +} + +pub fn transport_error(operation: &str, source: reqwest::Error) -> RegistryError { + RegistryError::Transport { + message: format!("{operation} request failed: {source}"), + } +} + +fn should_retry_status(status: StatusCode) -> bool { + status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() +} + +fn should_retry_transport_error(error: &reqwest::Error) -> bool { + error.is_connect() || error.is_timeout() || error.is_request() +} + +fn parse_retry_after_seconds(headers: &HeaderMap) -> Option { + let raw = headers.get("retry-after")?.to_str().ok()?.trim(); + raw.parse::().ok().map(|value| value.max(1)) +} + +fn compute_retry_delay( + attempt: u8, + policy: RetryPolicy, + retry_after: Option, +) -> Duration { + let fallback = exponential_backoff(attempt, policy.initial_backoff, policy.max_backoff); + match retry_after { + Some(delay) => { + let bounded = if delay > policy.max_backoff { + policy.max_backoff + } else { + delay + }; + if bounded.is_zero() { + Duration::from_millis(1) + } else { + bounded + } + 
} + None => fallback, + } +} + +fn exponential_backoff(attempt: u8, initial_backoff: Duration, max_backoff: Duration) -> Duration { + let shift = u32::from(attempt.saturating_sub(1)).min(16); + let multiplier = 2u128.pow(shift); + let initial_ms = initial_backoff.as_millis(); + let raw_ms = initial_ms.saturating_mul(multiplier); + let max_ms = max_backoff.as_millis(); + let bounded_ms = raw_ms.min(max_ms); + let bounded_ms_u64 = u64::try_from(bounded_ms).unwrap_or(u64::MAX); + Duration::from_millis(bounded_ms_u64) +} + +#[cfg(test)] +mod tests { + use super::*; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + #[test] + fn exponential_backoff_caps_at_maximum() { + let delay = exponential_backoff(8, Duration::from_millis(100), Duration::from_secs(1)); + assert_eq!(delay, Duration::from_secs(1)); + } + + #[test] + fn compute_retry_delay_prefers_retry_after_when_present() { + let policy = RetryPolicy { + max_attempts: 3, + initial_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(5), + }; + + let delay = compute_retry_delay(1, policy, Some(Duration::from_secs(2))); + assert_eq!(delay, Duration::from_secs(2)); + } + + #[tokio::test] + async fn send_with_retry_retries_retryable_statuses() { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/retry")) + .respond_with(ResponseTemplate::new(429).insert_header("retry-after", "1")) + .expect(2) + .mount(&server) + .await; + + let client = build_http_client(); + let url = format!("{}/retry", server.uri()); + let response = send_with_retry( + || client.get(&url), + "retry test", + RetryPolicy { + max_attempts: 2, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(10), + }, + ) + .await + .expect("request should complete with response"); + + assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS); + } + + #[tokio::test] + async fn send_with_retry_retries_transport_errors() { + let client = 
build_http_client(); + let mut attempts = 0usize; + let err = send_with_retry( + || { + attempts = attempts.saturating_add(1); + client.get("http://127.0.0.1:9") + }, + "transport retry test", + RetryPolicy { + max_attempts: 2, + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(2), + }, + ) + .await + .expect_err("transport errors should bubble up after retries"); + + assert!(matches!(err, RegistryError::Transport { .. })); + assert_eq!(attempts, 2); + } +} diff --git a/crates/osv/Cargo.toml b/crates/osv/Cargo.toml index b823ba9..2919667 100644 --- a/crates/osv/Cargo.toml +++ b/crates/osv/Cargo.toml @@ -7,6 +7,7 @@ edition.workspace = true reqwest.workspace = true serde.workspace = true safe-pkgs-core = { path = "../core" } +safe-pkgs-registry-http = { path = "../http" } [dev-dependencies] tokio.workspace = true diff --git a/crates/osv/src/lib.rs b/crates/osv/src/lib.rs index 9bb26fe..9048e9b 100644 --- a/crates/osv/src/lib.rs +++ b/crates/osv/src/lib.rs @@ -1,8 +1,11 @@ -use reqwest::{Client, StatusCode}; +use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use std::env; use safe_pkgs_core::{PackageAdvisory, RegistryEcosystem, RegistryError}; +use safe_pkgs_registry_http::{ + RetryPolicy, build_http_client, map_status_error, parse_json, send_with_retry, +}; const OSV_API_URL: &str = "https://api.osv.dev/v1/query"; @@ -22,6 +25,7 @@ async fn query_advisories_with_url( ecosystem: RegistryEcosystem, api_url: &str, ) -> Result, RegistryError> { + let http = build_http_client(); let body = OsvQueryRequest { package: OsvPackage { name: package_name.to_string(), @@ -30,38 +34,22 @@ async fn query_advisories_with_url( version: version.to_string(), }; - let response = Client::new() - .post(api_url) - .json(&body) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query OSV advisory API: {e}"), - })?; + let response = send_with_retry( + || http.post(api_url).json(&body), + "OSV advisory API", + 
RetryPolicy::default(), + ) + .await?; if response.status() == StatusCode::NOT_FOUND { return Ok(Vec::new()); } - if response.status().is_server_error() { - return Err(RegistryError::Transport { - message: format!("OSV advisory API server error {}", response.status()), - }); - } - if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("OSV advisory API returned status {}", response.status()), - }); + return Err(map_status_error("OSV advisory API", response.status())); } - let body: OsvQueryResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse OSV advisory response JSON: {e}"), - })?; + let body: OsvQueryResponse = parse_json(response, "OSV advisory response").await?; Ok(body .vulns @@ -176,7 +164,7 @@ mod tests { .await .expect_err("500 should be treated as transport error"); assert!(matches!(err, RegistryError::Transport { .. })); - assert!(err.to_string().contains("server error 500")); + assert!(err.to_string().contains("status 500")); } #[tokio::test] diff --git a/crates/registry/cargo/Cargo.toml b/crates/registry/cargo/Cargo.toml index 79742e0..43fcedd 100644 --- a/crates/registry/cargo/Cargo.toml +++ b/crates/registry/cargo/Cargo.toml @@ -12,6 +12,7 @@ tokio.workspace = true toml.workspace = true safe-pkgs-core = { path = "../../core" } safe-pkgs-osv = { path = "../../osv" } +safe-pkgs-registry-http = { path = "../../http" } [dev-dependencies] wiremock.workspace = true diff --git a/crates/registry/cargo/src/registry.rs b/crates/registry/cargo/src/registry.rs index 53fe7be..8c15087 100644 --- a/crates/registry/cargo/src/registry.rs +++ b/crates/registry/cargo/src/registry.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use reqwest::{Client, StatusCode}; +use reqwest::StatusCode; use serde::Deserialize; use std::collections::BTreeMap; use std::sync::Arc; @@ -11,13 +11,15 @@ use safe_pkgs_core::{ RegistryError, }; use 
safe_pkgs_osv::query_advisories; +use safe_pkgs_registry_http::{ + RetryPolicy, build_http_client, map_status_error, parse_json, send_with_retry, +}; -const CRATES_IO_USER_AGENT: &str = concat!("safe-pkgs/", env!("CARGO_PKG_VERSION")); const CRATES_PAGE_SIZE: usize = 100; #[derive(Clone)] pub struct CargoRegistryClient { - http: Client, + http: reqwest::Client, api_base_url: String, popular_names_cache: Arc>>>, } @@ -25,7 +27,7 @@ pub struct CargoRegistryClient { impl CargoRegistryClient { pub fn new() -> Self { Self { - http: Client::new(), + http: build_http_client(), api_base_url: "https://crates.io/api/v1".to_string(), popular_names_cache: Arc::new(RwLock::new(None)), } @@ -50,15 +52,12 @@ impl RegistryClient for CargoRegistryClient { self.api_base_url.trim_end_matches('/'), package ); - let response = self - .http - .get(&url) - .header("User-Agent", CRATES_IO_USER_AGENT) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query crates.io API: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&url), + "crates.io API", + RetryPolicy::default(), + ) + .await?; if response.status() == StatusCode::NOT_FOUND { return Err(RegistryError::NotFound { @@ -68,18 +67,10 @@ impl RegistryClient for CargoRegistryClient { } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("crates.io API returned status {}", response.status()), - }); + return Err(map_status_error("crates.io API", response.status())); } - let body: CrateDetailResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse crates.io response JSON: {e}"), - })?; + let body: CrateDetailResponse = parse_json(response, "crates.io response").await?; let latest = body .krate @@ -123,33 +114,22 @@ impl RegistryClient for CargoRegistryClient { self.api_base_url.trim_end_matches('/'), package ); - let response = self - .http - .get(&url) - 
.header("User-Agent", CRATES_IO_USER_AGENT) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query crates.io API: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&url), + "crates.io API", + RetryPolicy::default(), + ) + .await?; if response.status() == StatusCode::NOT_FOUND { return Ok(None); } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("crates.io API returned status {}", response.status()), - }); + return Err(map_status_error("crates.io API", response.status())); } - let body: CrateDownloadsResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse crates.io response JSON: {e}"), - })?; + let body: CrateDownloadsResponse = parse_json(response, "crates.io response").await?; Ok(body.krate.recent_downloads) } @@ -177,37 +157,26 @@ impl RegistryClient for CargoRegistryClient { while names.len() < limit { let url = format!("{}/crates", self.api_base_url.trim_end_matches('/')); let per_page = CRATES_PAGE_SIZE.min(limit.saturating_sub(names.len())); - let response = self - .http - .get(&url) - .header("User-Agent", CRATES_IO_USER_AGENT) - .query(&[ - ("page", page.to_string()), - ("per_page", per_page.to_string()), - ("sort", "downloads".to_string()), - ]) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query crates.io popular crates index: {e}"), - })?; + let query = vec![ + ("page", page.to_string()), + ("per_page", per_page.to_string()), + ("sort", "downloads".to_string()), + ]; + let response = send_with_retry( + || self.http.get(&url).query(&query), + "crates.io popular crates index", + RetryPolicy::default(), + ) + .await?; if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!( - "crates.io popular crates index returned status {}", - response.status() - ), - }); + return Err(map_status_error( + "crates.io 
popular crates index", + response.status(), + )); } - let body: CratesListResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse crates.io list response JSON: {e}"), - })?; + let body: CratesListResponse = parse_json(response, "crates.io list response").await?; if body.crates.is_empty() { break; @@ -291,7 +260,7 @@ mod tests { fn test_client(base_url: &str) -> CargoRegistryClient { CargoRegistryClient { - http: Client::new(), + http: build_http_client(), api_base_url: base_url.to_string(), popular_names_cache: Arc::new(RwLock::new(None)), } diff --git a/crates/registry/npm/Cargo.toml b/crates/registry/npm/Cargo.toml index d6ed0d3..0d3345b 100644 --- a/crates/registry/npm/Cargo.toml +++ b/crates/registry/npm/Cargo.toml @@ -13,6 +13,7 @@ serde_json.workspace = true tokio.workspace = true safe-pkgs-core = { path = "../../core" } safe-pkgs-osv = { path = "../../osv" } +safe-pkgs-registry-http = { path = "../../http" } [dev-dependencies] wiremock.workspace = true diff --git a/crates/registry/npm/src/registry.rs b/crates/registry/npm/src/registry.rs index 07ac95f..71c8c8e 100644 --- a/crates/registry/npm/src/registry.rs +++ b/crates/registry/npm/src/registry.rs @@ -1,11 +1,10 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use reqwest::{Client, StatusCode}; +use reqwest::StatusCode; use serde::Deserialize; use std::collections::{BTreeMap, HashMap, HashSet}; use std::env; use std::sync::Arc; -use std::time::Duration; use tokio::sync::RwLock; use safe_pkgs_core::{ @@ -13,15 +12,17 @@ use safe_pkgs_core::{ RegistryError, }; use safe_pkgs_osv::query_advisories; +use safe_pkgs_registry_http::{ + RetryPolicy, build_http_client, map_status_error, parse_json, send_with_retry, +}; const NPMS_POPULAR_QUERY: &str = "not:deprecated"; const NPMS_PAGE_SIZE: usize = 250; const NPM_BULK_DOWNLOAD_MAX_PACKAGES: usize = 128; -const NPM_DOWNLOAD_API_MAX_RETRY_ATTEMPTS: u8 = 2; #[derive(Clone)] pub struct 
NpmRegistryClient { - http: Client, + http: reqwest::Client, base_url: String, downloads_api_base_url: String, popular_index_api_base_url: String, @@ -32,7 +33,7 @@ pub struct NpmRegistryClient { impl NpmRegistryClient { pub fn new() -> Self { Self { - http: Client::new(), + http: build_http_client(), base_url: env::var("SAFE_PKGS_NPM_REGISTRY_API_BASE_URL") .unwrap_or_else(|_| "https://registry.npmjs.org".to_string()), downloads_api_base_url: env::var("SAFE_PKGS_NPM_DOWNLOADS_API_BASE_URL") @@ -81,31 +82,22 @@ impl NpmRegistryClient { joined ); - let response = - self.http - .get(&url) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query npm bulk downloads API: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&url), + "npm bulk downloads API", + RetryPolicy::default(), + ) + .await?; if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!( - "npm bulk downloads API returned status {}", - response.status() - ), - }); + return Err(map_status_error( + "npm bulk downloads API", + response.status(), + )); } let body: NpmBulkDownloadsResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse npm bulk downloads response JSON: {e}"), - })?; + parse_json(response, "npm bulk downloads response").await?; let mut cache = self.prefetched_downloads.write().await; for item in body.downloads { @@ -137,14 +129,12 @@ impl RegistryClient for NpmRegistryClient { let encoded_name = Self::encode_package_name(package); let url = format!("{}/{}", self.base_url.trim_end_matches('/'), encoded_name); - let response = self - .http - .get(&url) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query npm registry: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&url), + "npm registry", + RetryPolicy::default(), + ) + .await?; if response.status() == 
StatusCode::NOT_FOUND { return Err(RegistryError::NotFound { @@ -154,18 +144,10 @@ impl RegistryClient for NpmRegistryClient { } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("npm registry returned status {}", response.status()), - }); + return Err(map_status_error("npm registry", response.status())); } - let body: NpmPackageResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse npm response JSON: {e}"), - })?; + let body: NpmPackageResponse = parse_json(response, "npm registry response").await?; let latest = body .dist_tags @@ -219,30 +201,12 @@ impl RegistryClient for NpmRegistryClient { encoded_name ); - let mut attempts = 0u8; - let response = loop { - attempts = attempts.saturating_add(1); - let response = - self.http - .get(&url) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query npm downloads API: {e}"), - })?; - - if response.status() == StatusCode::TOO_MANY_REQUESTS - && attempts < NPM_DOWNLOAD_API_MAX_RETRY_ATTEMPTS - { - let retry_seconds = parse_retry_after_seconds(response.headers()) - .unwrap_or(1) - .clamp(1, 5); - tokio::time::sleep(Duration::from_secs(retry_seconds)).await; - continue; - } - - break response; - }; + let response = send_with_retry( + || self.http.get(&url), + "npm downloads API", + RetryPolicy::default(), + ) + .await?; if response.status() == StatusCode::NOT_FOUND { let mut cache = self.prefetched_downloads.write().await; @@ -251,18 +215,10 @@ impl RegistryClient for NpmRegistryClient { } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("npm downloads API returned status {}", response.status()), - }); + return Err(map_status_error("npm downloads API", response.status())); } - let body: NpmDownloadsResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse 
npm downloads response JSON: {e}"), - })?; + let body: NpmDownloadsResponse = parse_json(response, "npm downloads response").await?; let mut cache = self.prefetched_downloads.write().await; cache.insert(package.to_string(), body.downloads); @@ -297,36 +253,23 @@ impl RegistryClient for NpmRegistryClient { self.popular_index_api_base_url.trim_end_matches('/') ); let size = NPMS_PAGE_SIZE.min(limit.saturating_sub(names.len())); - let response = self - .http - .get(url) - .query(&[ - ("q", NPMS_POPULAR_QUERY.to_string()), - ("size", size.to_string()), - ("from", from.to_string()), - ]) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query npms popularity index: {e}"), - })?; + let query = vec![ + ("q", NPMS_POPULAR_QUERY.to_string()), + ("size", size.to_string()), + ("from", from.to_string()), + ]; + let response = send_with_retry( + || self.http.get(&url).query(&query), + "npms popularity index", + RetryPolicy::default(), + ) + .await?; if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!( - "npms popularity index returned status {}", - response.status() - ), - }); + return Err(map_status_error("npms popularity index", response.status())); } - let body: NpmsSearchResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse npms search response JSON: {e}"), - })?; + let body: NpmsSearchResponse = parse_json(response, "npms search response").await?; if body.results.is_empty() { break; @@ -424,11 +367,6 @@ struct NpmsPackage { name: String, } -fn parse_retry_after_seconds(headers: &reqwest::header::HeaderMap) -> Option { - let raw = headers.get("retry-after")?.to_str().ok()?; - raw.parse::().ok() -} - #[derive(Debug, Deserialize)] struct NpmBulkDownloadsResponse { #[serde(default)] @@ -444,13 +382,12 @@ struct NpmBulkDownloadItem { #[cfg(test)] mod tests { use super::*; - use reqwest::header::{HeaderMap, HeaderValue}; use 
wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; fn test_client(base_url: &str) -> NpmRegistryClient { NpmRegistryClient { - http: Client::new(), + http: build_http_client(), base_url: base_url.to_string(), downloads_api_base_url: base_url.to_string(), popular_index_api_base_url: base_url.to_string(), @@ -468,16 +405,6 @@ mod tests { assert_eq!(NpmRegistryClient::encode_package_name("lodash"), "lodash"); } - #[test] - fn parse_retry_after_reads_valid_seconds_only() { - let mut headers = HeaderMap::new(); - headers.insert("retry-after", HeaderValue::from_static("3")); - assert_eq!(parse_retry_after_seconds(&headers), Some(3)); - - headers.insert("retry-after", HeaderValue::from_static("not-a-number")); - assert_eq!(parse_retry_after_seconds(&headers), None); - } - #[tokio::test] async fn fetch_package_parses_scripts_and_deprecated_versions() { let mock_server = MockServer::start().await; diff --git a/crates/registry/pypi/Cargo.toml b/crates/registry/pypi/Cargo.toml index f73ed91..6cc9108 100644 --- a/crates/registry/pypi/Cargo.toml +++ b/crates/registry/pypi/Cargo.toml @@ -12,6 +12,7 @@ tokio.workspace = true toml.workspace = true safe-pkgs-core = { path = "../../core" } safe-pkgs-osv = { path = "../../osv" } +safe-pkgs-registry-http = { path = "../../http" } [dev-dependencies] wiremock.workspace = true diff --git a/crates/registry/pypi/src/registry.rs b/crates/registry/pypi/src/registry.rs index d385532..a582f1d 100644 --- a/crates/registry/pypi/src/registry.rs +++ b/crates/registry/pypi/src/registry.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use reqwest::{Client, StatusCode}; +use reqwest::StatusCode; use serde::Deserialize; use std::collections::{BTreeMap, HashSet}; use std::env; @@ -12,8 +12,10 @@ use safe_pkgs_core::{ RegistryError, }; use safe_pkgs_osv::query_advisories; +use safe_pkgs_registry_http::{ + RetryPolicy, build_http_client, map_status_error, parse_json, send_with_retry, +}; 
-const PYPI_USER_AGENT: &str = concat!("safe-pkgs/", env!("CARGO_PKG_VERSION")); const DEFAULT_PYPI_API_BASE_URL: &str = "https://pypi.org/pypi"; const DEFAULT_PYPI_DOWNLOADS_API_BASE_URL: &str = "https://pypistats.org/api/packages"; const DEFAULT_PYPI_POPULAR_INDEX_URL: &str = @@ -21,7 +23,7 @@ const DEFAULT_PYPI_POPULAR_INDEX_URL: &str = #[derive(Clone)] pub struct PypiRegistryClient { - http: Client, + http: reqwest::Client, package_api_base_url: String, downloads_api_base_url: String, popular_index_url: String, @@ -31,7 +33,7 @@ pub struct PypiRegistryClient { impl PypiRegistryClient { pub fn new() -> Self { Self { - http: Client::new(), + http: build_http_client(), package_api_base_url: env::var("SAFE_PKGS_PYPI_PACKAGE_API_BASE_URL") .unwrap_or_else(|_| DEFAULT_PYPI_API_BASE_URL.to_string()), downloads_api_base_url: env::var("SAFE_PKGS_PYPI_DOWNLOADS_API_BASE_URL") @@ -61,15 +63,8 @@ impl RegistryClient for PypiRegistryClient { self.package_api_base_url.trim_end_matches('/'), package ); - let response = self - .http - .get(&url) - .header("User-Agent", PYPI_USER_AGENT) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query PyPI API: {e}"), - })?; + let response = + send_with_retry(|| self.http.get(&url), "PyPI API", RetryPolicy::default()).await?; if response.status() == StatusCode::NOT_FOUND { return Err(RegistryError::NotFound { @@ -79,18 +74,10 @@ impl RegistryClient for PypiRegistryClient { } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("PyPI API returned status {}", response.status()), - }); + return Err(map_status_error("PyPI API", response.status())); } - let body: PypiPackageResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse PyPI response JSON: {e}"), - })?; + let body: PypiPackageResponse = parse_json(response, "PyPI response").await?; let latest = body .info @@ -147,33 +134,22 @@ impl 
RegistryClient for PypiRegistryClient { self.downloads_api_base_url.trim_end_matches('/'), package ); - let response = self - .http - .get(&url) - .header("User-Agent", PYPI_USER_AGENT) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query PyPI downloads API: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&url), + "PyPI downloads API", + RetryPolicy::default(), + ) + .await?; if response.status() == StatusCode::NOT_FOUND { return Ok(None); } if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!("PyPI downloads API returned status {}", response.status()), - }); + return Err(map_status_error("PyPI downloads API", response.status())); } - let body: PypiDownloadsResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse PyPI downloads response JSON: {e}"), - })?; + let body: PypiDownloadsResponse = parse_json(response, "PyPI downloads response").await?; Ok(body.data.last_week) } @@ -195,32 +171,18 @@ impl RegistryClient for PypiRegistryClient { } } - let response = self - .http - .get(&self.popular_index_url) - .header("User-Agent", PYPI_USER_AGENT) - .send() - .await - .map_err(|e| RegistryError::Transport { - message: format!("unable to query PyPI popularity index: {e}"), - })?; + let response = send_with_retry( + || self.http.get(&self.popular_index_url), + "PyPI popularity index", + RetryPolicy::default(), + ) + .await?; if !response.status().is_success() { - return Err(RegistryError::Transport { - message: format!( - "PyPI popularity index returned status {}", - response.status() - ), - }); + return Err(map_status_error("PyPI popularity index", response.status())); } - let body: TopPypiResponse = - response - .json() - .await - .map_err(|e| RegistryError::InvalidResponse { - message: format!("failed to parse PyPI popularity index JSON: {e}"), - })?; + let body: TopPypiResponse = 
parse_json(response, "PyPI popularity index response").await?; let mut names = Vec::new(); let mut seen = HashSet::new(); @@ -326,7 +288,7 @@ mod tests { fn test_client(base_url: &str) -> PypiRegistryClient { PypiRegistryClient { - http: Client::new(), + http: build_http_client(), package_api_base_url: base_url.to_string(), downloads_api_base_url: base_url.to_string(), popular_index_url: format!("{}/top.json", base_url.trim_end_matches('/')),