From 1bcf069803d7ed80d6dadf5914df13d5e4a7cfb0 Mon Sep 17 00:00:00 2001 From: Chris Pryer Date: Mon, 9 Mar 2026 22:48:35 -0400 Subject: [PATCH] Implement client event_hooks flow --- src/client.rs | 596 ++++++++++++++++++++++++++----------- tests/test_async_client.py | 52 ++++ tests/test_client.py | 50 ++++ 3 files changed, 530 insertions(+), 168 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9e25e42..75e35da 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,7 @@ use crate::auth::{PyBasicAuth, PyDigestAuth}; use crate::config::{PyLimits, PyTimeout}; use crate::cookies::PyCookies; use crate::json::json_dumps; -use crate::models::{version_str, PyHeaders, PyRequest, PyResponse, ResponseStream}; +use crate::models::{version_str, PyHeaders, PyRequest, PyResponse, PyURL, ResponseStream}; use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList, PyTuple}; @@ -532,6 +532,12 @@ enum AuthKind { Digest(Py), } +#[derive(Default)] +struct EventHooks { + request: Vec>, + response: Vec>, +} + struct MountTransport { prefix: String, transport: Py, @@ -559,6 +565,159 @@ fn clone_auth_kind(py: Python<'_>, auth: &AuthKind) -> AuthKind { } } +fn parse_event_hook_list(hooks_obj: Bound<'_, PyAny>, hook_type: &str) -> PyResult>> { + if hooks_obj.is_none() { + return Ok(Vec::new()); + } + let iter = hooks_obj.try_iter().map_err(|_| { + PyTypeError::new_err(format!( + "event_hooks['{hook_type}'] must be an iterable of callables" + )) + })?; + let mut hooks = Vec::new(); + for hook in iter { + let hook = hook?; + if !hook.is_callable() { + return Err(PyTypeError::new_err(format!( + "event_hooks['{hook_type}'] entries must be callable" + ))); + } + hooks.push(hook.unbind()); + } + Ok(hooks) +} + +fn parse_event_hooks_arg(py: Python<'_>, event_hooks: Option>) -> PyResult { + let Some(event_hooks) = event_hooks else { + return Ok(EventHooks::default()); + }; + let bound = event_hooks.bind(py); + if bound.is_none() { + return Ok(EventHooks::default()); + } + + let request_hooks_obj = bound + .call_method1("get", ("request", PyList::empty(py))) + .map_err(|_| { + PyTypeError::new_err( + "event_hooks must be a mapping with optional 'request' and 'response' entries", + ) + })?; + let response_hooks_obj = bound + .call_method1("get", ("response", PyList::empty(py))) + .map_err(|_| { + PyTypeError::new_err( + "event_hooks must be a mapping with optional 'request' and 'response' entries", + ) + })?; + + Ok(EventHooks { + request: parse_event_hook_list(request_hooks_obj, "request")?, + response: parse_event_hook_list(response_hooks_obj, "response")?, + }) +} + +fn clone_hooks(py: Python<'_>, hooks: &[Py]) -> Vec> { + hooks.iter().map(|hook| hook.clone_ref(py)).collect() +} + +fn run_sync_request_hooks( + py: Python<'_>, + hooks: &[Py], + request: &Py, +) -> PyResult<()> { + for hook in hooks { + hook.bind(py).call1((request.clone_ref(py),))?; + } + Ok(()) +} + +fn run_sync_response_hooks( + py: Python<'_>, + hooks: &[Py], + response: &Py, +) -> PyResult<()> { + for hook in hooks { + hook.bind(py).call1((response.clone_ref(py),))?; + } + Ok(()) +} + +async fn run_async_request_hooks(hooks: Vec>, request: Py) -> PyResult<()> { + for hook in hooks { + let hook_call = Python::attach(|py| { + let hook_result = hook.bind(py).call1((request.clone_ref(py),))?; + pyo3_async_runtimes::tokio::into_future(hook_result) + })?; + let _ = hook_call.await?; + } + Ok(()) +} + +async fn run_async_response_hooks(hooks: Vec>, response: Py) -> PyResult<()> { + for hook in hooks { + let hook_call = Python::attach(|py| { + let hook_result = hook.bind(py).call1((response.clone_ref(py),))?; + pyo3_async_runtimes::tokio::into_future(hook_result) + })?; + let _ = hook_call.await?; + } + Ok(()) +} + +fn body_bytes_for_hook_request(body: &RequestBody) -> Vec { + match body { + RequestBody::Empty => Vec::new(), + RequestBody::Bytes(bytes) => bytes.clone(), + RequestBody::Json(json) => json.as_bytes().to_vec(), + RequestBody::Form(pairs) => form_encode_pairs(pairs).into_bytes(), + } +} + +fn body_content_type_for_hook_request(body: &RequestBody) -> Option<&'static str> { + match body { + RequestBody::Json(_) => Some("application/json"), + RequestBody::Form(_) => Some("application/x-www-form-urlencoded"), + RequestBody::Empty | RequestBody::Bytes(_) => None, + } +} + +fn build_hook_request( + py: Python<'_>, + method: &str, + url: &str, + headers: PyHeaders, + content: Vec, +) -> PyResult> { + let request = PyRequest { + method: method.to_uppercase(), + url: PyURL::from_str(url)?, + headers, + content, + py_stream: None, + extensions: PyDict::new(py).into_any().unbind(), + }; + Py::new(py, request) +} + +fn extract_request_parts( + py: Python<'_>, + request_obj: &Py, +) -> PyResult<(String, String, Vec<(String, String)>, Option>)> { + let request = request_obj.bind(py).borrow(); + let body = if request.content.is_empty() && request.py_stream.is_none() { + None + } else { + Some(request.read(py)?) + }; + Ok(( + request.method.clone(), + request.url.inner.to_string(), + request.headers.inner.clone(), + body, + )) +} + fn build_blocking_request( client: &reqwest::blocking::Client, method: &str, @@ -1050,6 +1209,7 @@ pub struct PyClient { default_auth: Option, transport: Option>, mounts: Vec, + event_hooks: EventHooks, } impl PyClient { @@ -1135,12 +1295,12 @@ impl PyClient { default_encoding: Option>, block_private_redirects: bool, ) -> PyResult { - let _ = event_hooks; let _ = default_encoding; let default_query = params_to_query(py, params)?; let cookie_pairs = parse_cookies_arg(py, cookies)?; let cert_identity = parse_cert_arg(py, cert)?; let mounts = parse_mounts_arg(py, mounts)?; + let event_hooks = parse_event_hooks_arg(py, event_hooks)?; let cookie_jar = Arc::new(reqwest::cookie::Jar::default()); let cookie_state = Arc::new(Mutex::new(CookieBindingState { pending_pairs: cookie_pairs, @@ -1211,6 +1371,7 @@ impl PyClient { default_auth, transport, mounts, + event_hooks, }) } @@ -1272,7 +1433,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { let follow = follow_redirects.unwrap_or(self.follow_redirects); let client = if follow == self.follow_redirects { self.get_client()?.clone() @@ -1308,95 +1469,126 @@ impl PyClient { let effective_auth = req_auth.as_ref().or(self.default_auth.as_ref()); let body = build_body(py, content, json, data)?; + let request_hooks = clone_hooks(py, &self.event_hooks.request); + let response_hooks = clone_hooks(py, &self.event_hooks.response); + + let mut hook_headers = self.default_headers.clone(); + if let Some(extra) = &extra_headers { + hook_headers.inner.extend(extra.inner.clone()); + } + if let Some(content_type) = body_content_type_for_hook_request(&body) { + hook_headers + .inner + .push(("content-type".to_string(), content_type.to_string())); + } + if let Some(AuthKind::Basic(header_val)) = effective_auth { + hook_headers + .inner + .push(("authorization".to_string(), header_val.clone())); + } + + let request_obj = build_hook_request( + py, + method, + &full_url, + hook_headers, + body_bytes_for_hook_request(&body), + )?; + run_sync_request_hooks(py, &request_hooks, &request_obj)?; + + let send_request = |request_obj: &Py| -> PyResult { + let (method_str, request_url, headers, body_bytes) = + extract_request_parts(py, request_obj)?; + self.bind_default_cookies_for_url(&request_url); + + let method = reqwest::Method::from_bytes(method_str.as_bytes()) + .map_err(|_| PyValueError::new_err("Invalid method"))?; + let mut builder = client.request(method, &request_url); + for (k, v) in &headers { + builder = builder.header(k.as_str(), v.as_str()); + } + if let Some(body_bytes) = body_bytes { + builder = builder.body(body_bytes); + } + if let Some(timeout) = req_timeout { + builder = builder.timeout(timeout); + } + crate::without_gil(|| builder.send()).map_err(crate::map_reqwest_error) + }; // DigestAuth: two-pass — first request without auth, retry with credentials on 401 if let Some(AuthKind::Digest(digest_py)) = effective_auth { let digest_py = digest_py.clone_ref(py); - let method_str = method.to_string(); - let url_str = { - // RFC 7616 §3.4: digest-uri is the Request-URI (path + query, no scheme/host) - if let Ok(parsed) = url::Url::parse(&full_url) { - match parsed.query() { - Some(q) => format!("{}?{}", parsed.path(), q), - None => parsed.path().to_string(), - } - } else { - full_url.clone() - } - }; - let full_url2 = full_url.clone(); - let default_headers2 = self.default_headers.clone(); - let extra_headers2 = extra_headers.clone(); - let client2 = client.clone(); - - let builder = build_blocking_request( - &client, - method, - &full_url, - extra_headers.as_ref(), - &self.default_headers, - body, - None, // no auth on first pass - req_timeout, - )?; let start = Instant::now(); - - let resp = - crate::without_gil(move || builder.send().map_err(crate::map_reqwest_error))?; - - if resp.status().as_u16() == 401 { - let www_auth = resp - .headers() - .get("www-authenticate") - .and_then(|v| v.to_str().ok()) - .unwrap_or("") - .to_string(); - + let resp = send_request(&request_obj)?; + let elapsed = start.elapsed().as_millis(); + let first_response = + PyResponse::from_blocking(resp, elapsed, Some(request_obj.clone_ref(py)))?; + let first_response_obj = Py::new(py, first_response)?; + let first_response_any = first_response_obj.clone_ref(py).into_any(); + run_sync_response_hooks(py, &response_hooks, &first_response_any)?; + + if first_response_obj.bind(py).borrow().status_code == 401 { + let www_auth = { + let response_ref = first_response_obj.bind(py).borrow(); + response_ref + .headers + .get("www-authenticate", None) + .unwrap_or_default() + }; + let (method_str, request_url, _, _) = extract_request_parts(py, &request_obj)?; + let url_str = { + // RFC 7616 §3.4: digest-uri is the Request-URI (path + query, no scheme/host) + if let Ok(parsed) = url::Url::parse(&request_url) { + match parsed.query() { + Some(q) => format!("{}?{}", parsed.path(), q), + None => parsed.path().to_string(), + } + } else { + request_url + } + }; let auth_header = { let digest_ref = digest_py.bind(py); let digest = digest_ref.borrow(); digest.compute_header(&method_str, &url_str, &www_auth)? }; - let builder2 = build_blocking_request( - &client2, - &method_str, - &full_url2, - extra_headers2.as_ref(), - &default_headers2, - RequestBody::Empty, - None, - req_timeout, - )?; - let builder2 = builder2.header("authorization", auth_header.as_str()); + let second_request_obj = Py::new(py, request_obj.bind(py).borrow().clone())?; + { + // Preserve existing behavior: digest retry does not resend request content. + let mut second_request = second_request_obj.bind(py).borrow_mut(); + second_request.content.clear(); + second_request.py_stream = None; + second_request.set_header("authorization", auth_header.as_str()); + } + run_sync_request_hooks(py, &request_hooks, &second_request_obj)?; - crate::without_gil(move || { - let resp2 = builder2.send().map_err(crate::map_reqwest_error)?; - let elapsed = start.elapsed().as_millis(); - PyResponse::from_blocking(resp2, elapsed, None) - }) - } else { + let start = Instant::now(); + let resp2 = send_request(&second_request_obj)?; let elapsed = start.elapsed().as_millis(); - PyResponse::from_blocking(resp, elapsed, None) + let second_response = PyResponse::from_blocking( + resp2, + elapsed, + Some(second_request_obj.clone_ref(py)), + )?; + let second_response_obj = Py::new(py, second_response)?; + let second_response_any = second_response_obj.clone_ref(py).into_any(); + run_sync_response_hooks(py, &response_hooks, &second_response_any)?; + Ok(second_response_obj) + } else { + Ok(first_response_obj) } } else { - let builder = build_blocking_request( - &client, - method, - &full_url, - extra_headers.as_ref(), - &self.default_headers, - body, - effective_auth, - req_timeout, - )?; let start = Instant::now(); - - // Release GIL while blocking on I/O - let result = crate::without_gil(|| builder.send()); - let resp = result.map_err(crate::map_reqwest_error)?; + let resp = send_request(&request_obj)?; let elapsed = start.elapsed().as_millis(); - PyResponse::from_blocking(resp, elapsed, None) + let response = + PyResponse::from_blocking(resp, elapsed, Some(request_obj.clone_ref(py)))?; + let response_obj = Py::new(py, response)?; + let response_any = response_obj.clone_ref(py).into_any(); + run_sync_response_hooks(py, &response_hooks, &response_any)?; + Ok(response_obj) } } @@ -1422,7 +1614,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "GET", @@ -1459,7 +1651,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "POST", @@ -1496,7 +1688,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "PUT", @@ -1533,7 +1725,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "PATCH", @@ -1570,7 +1762,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "DELETE", @@ -1601,7 +1793,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "HEAD", @@ -1632,7 +1824,7 @@ impl PyClient { auth: Option>, timeout: Option>, follow_redirects: Option, - ) -> PyResult { + ) -> PyResult> { self.request( py, "OPTIONS", @@ -1803,12 +1995,22 @@ impl PyClient { Some(ref a) => Some(extract_auth(py, a)?), None => None, }; - let default_auth = { + let (default_auth, request_hooks, response_hooks) = { let this = slf.borrow(); - this.default_auth.as_ref().map(|a| clone_auth_kind(py, a)) + ( + this.default_auth.as_ref().map(|a| clone_auth_kind(py, a)), + clone_hooks(py, &this.event_hooks.request), + clone_hooks(py, &this.event_hooks.response), + ) }; let effective_auth = req_auth.or(default_auth); let request_obj = Py::new(py, request.clone())?; + if let Some(AuthKind::Basic(header_val)) = &effective_auth { + let mut req_mut = request_obj.bind(py).borrow_mut(); + req_mut.set_header("authorization", header_val); + } + run_sync_request_hooks(py, &request_hooks, &request_obj)?; + let request_url = { let req_ref = request_obj.bind(py).borrow(); req_ref.url.inner.to_string() @@ -1817,10 +2019,6 @@ impl PyClient { let this = slf.borrow(); this.transport_for_url(py, &request_url) }; - if let Some(AuthKind::Basic(header_val)) = &effective_auth { - let mut req_mut = request_obj.bind(py).borrow_mut(); - req_mut.set_header("authorization", header_val); - } if let Some(transport) = transport_obj { let transport_bound = transport.into_bound(py).into_any(); if transport_bound.hasattr("handle_request")? { @@ -1831,7 +2029,9 @@ impl PyClient { py_response.request = Some(request_obj.clone_ref(py)); } } - return Ok(response); + let response_obj = response.unbind(); + run_sync_response_hooks(py, &response_hooks, &response_obj)?; + return Ok(response_obj.into_bound(py).into_any()); } } @@ -1853,23 +2053,8 @@ impl PyClient { this.cookie_jar.clone(), )? }; - let (method_str, url, mut headers, body) = { - let req_ref = request_obj.bind(py).borrow(); - let method_str = req_ref.method.clone(); - let url = req_ref.url.inner.to_string(); - let headers = req_ref.headers.inner.clone(); - let body = if req_ref.content.is_empty() && req_ref.py_stream.is_none() { - None - } else { - Some(req_ref.read(py)?) - }; - (method_str, url, headers, body) - }; - this.bind_default_cookies_for_url(&url); - if let Some(AuthKind::Basic(header_val)) = &effective_auth { - headers.retain(|(k, _)| k != "authorization"); - headers.push(("authorization".to_string(), header_val.clone())); - } + let (method_str, url, headers, body) = extract_request_parts(py, &request_obj)?; + this.bind_default_cookies_for_url(&request_url); let method = reqwest::Method::from_bytes(method_str.as_bytes()) .map_err(|_| PyValueError::new_err("Invalid method"))?; @@ -1890,6 +2075,8 @@ impl PyClient { PyResponse::from_blocking(response, elapsed, Some(request_obj.clone_ref(py)))? }; let response_obj = Py::new(py, py_response)?; + let response_any = response_obj.clone_ref(py).into_any(); + run_sync_response_hooks(py, &response_hooks, &response_any)?; Ok(response_obj.into_bound(py).into_any()) } @@ -2075,6 +2262,7 @@ pub struct PyAsyncClient { default_auth: Option, transport: Option>, mounts: Vec, + event_hooks: EventHooks, } impl PyAsyncClient { @@ -2207,12 +2395,12 @@ impl PyAsyncClient { default_encoding: Option>, block_private_redirects: bool, ) -> PyResult { - let _ = event_hooks; let _ = default_encoding; let default_query = params_to_query(py, params)?; let cookie_pairs = parse_cookies_arg(py, cookies)?; let cert_identity = parse_cert_arg(py, cert)?; let mounts = parse_mounts_arg(py, mounts)?; + let event_hooks = parse_event_hooks_arg(py, event_hooks)?; let cookie_jar = Arc::new(reqwest::cookie::Jar::default()); let cookie_state = Arc::new(Mutex::new(CookieBindingState { pending_pairs: cookie_pairs, @@ -2282,6 +2470,7 @@ impl PyAsyncClient { default_auth, transport, mounts, + event_hooks, }) } @@ -2334,6 +2523,8 @@ impl PyAsyncClient { let resolved_url = self.resolve_url(&url); let full_url = merge_url_query(&resolved_url, self.default_query.as_deref(), None); self.bind_default_cookies_for_url(&full_url); + let request_hooks = clone_hooks(py, &self.event_hooks.request); + let response_hooks = clone_hooks(py, &self.event_hooks.response); let extra_headers = match headers { None => None, @@ -2374,32 +2565,40 @@ impl PyAsyncClient { }; let req_timeout = timeout.or(self.timeout.read).map(Duration::from_secs_f64); - - let headers_vec: Vec<(String, String)> = merged_headers.inner.clone(); + if let Some(ref ct) = content_type { + merged_headers + .inner + .push(("content-type".to_string(), ct.clone())); + } + if let Some(ref header_val) = auth_header { + merged_headers + .inner + .push(("authorization".to_string(), header_val.clone())); + } + let request_obj = build_hook_request( + py, + &method, + &full_url, + merged_headers, + body_bytes.clone().unwrap_or_default(), + )?; + let request_obj_for_hooks = request_obj.clone_ref(py); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let mut builder = client.request( - reqwest::Method::from_bytes(method.to_uppercase().as_bytes()) - .map_err(|_| PyValueError::new_err("Invalid method"))?, - &full_url, - ); + run_async_request_hooks(request_hooks, request_obj_for_hooks).await?; + let (method_str, url, headers, body) = + Python::attach(|py| extract_request_parts(py, &request_obj))?; - for (k, v) in &headers_vec { - builder = builder.header(k.as_str(), v.as_str()); - } + let method = reqwest::Method::from_bytes(method_str.as_bytes()) + .map_err(|_| PyValueError::new_err("Invalid method"))?; + let mut builder = client.request(method, &url); - if let Some(ref ct) = content_type { - builder = builder.header("content-type", ct.as_str()); + for (k, v) in &headers { + builder = builder.header(k.as_str(), v.as_str()); } - - if let Some(body) = body_bytes { + if let Some(body) = body { builder = builder.body(body); } - - if let Some(ref header_val) = auth_header { - builder = builder.header("authorization", header_val.as_str()); - } - if let Some(dur) = req_timeout { builder = builder.timeout(dur); } @@ -2407,7 +2606,17 @@ impl PyAsyncClient { let start = Instant::now(); let response = builder.send().await.map_err(crate::map_reqwest_error)?; let elapsed = start.elapsed().as_millis(); - convert_async_response(response, elapsed, None).await + let request_for_response = Python::attach(|py| request_obj.clone_ref(py)); + let response = + convert_async_response(response, elapsed, Some(request_for_response)).await?; + let response_obj = + Python::attach(|py| Py::new(py, response).map(|obj| obj.into_any()))?; + run_async_response_hooks( + response_hooks, + Python::attach(|py| response_obj.clone_ref(py)), + ) + .await?; + Ok(response_obj) }) } @@ -2761,12 +2970,20 @@ impl PyAsyncClient { Some(ref a) => Some(extract_auth(py, a)?), None => None, }; - let default_auth = { + let (default_auth, request_hooks, response_hooks) = { let this = slf.borrow(); - this.default_auth.as_ref().map(|a| clone_auth_kind(py, a)) + ( + this.default_auth.as_ref().map(|a| clone_auth_kind(py, a)), + clone_hooks(py, &this.event_hooks.request), + clone_hooks(py, &this.event_hooks.response), + ) }; let effective_auth = req_auth.or(default_auth); let request_obj = Py::new(py, request.clone())?; + if let Some(AuthKind::Basic(header_val)) = &effective_auth { + let mut req_mut = request_obj.bind(py).borrow_mut(); + req_mut.set_header("authorization", header_val); + } let request_url = { let req_ref = request_obj.bind(py).borrow(); req_ref.url.inner.to_string() @@ -2775,25 +2992,73 @@ impl PyAsyncClient { let this = slf.borrow(); this.transport_for_url(py, &request_url) }; - if let Some(AuthKind::Basic(header_val)) = &effective_auth { - let mut req_mut = request_obj.bind(py).borrow_mut(); - req_mut.set_header("authorization", header_val); - } if let Some(transport) = transport_obj { let transport_bound = transport.into_bound(py).into_any(); if transport_bound.hasattr("handle_async_request")? { - return transport_bound - .call_method1("handle_async_request", (request_obj.clone_ref(py),)); + let transport_obj = transport_bound.unbind(); + let request_obj_for_hooks = request_obj.clone_ref(py); + let request_obj_for_transport = request_obj.clone_ref(py); + let request_obj_for_response = request_obj.clone_ref(py); + return pyo3_async_runtimes::tokio::future_into_py(py, async move { + run_async_request_hooks(request_hooks, request_obj_for_hooks).await?; + let awaitable = Python::attach(|py| { + transport_obj + .bind(py) + .call_method1( + "handle_async_request", + (request_obj_for_transport.clone_ref(py),), + ) + .map(|obj| obj.unbind()) + })?; + let response_obj = Python::attach(|py| { + pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)) + })? + .await?; + Python::attach(|py| { + if let Ok(mut py_response) = + response_obj.bind(py).extract::>() + { + if py_response.request.is_none() { + py_response.request = Some(request_obj_for_response.clone_ref(py)); + } + } + Ok::<(), PyErr>(()) + })?; + run_async_response_hooks( + response_hooks, + Python::attach(|py| response_obj.clone_ref(py)), + ) + .await?; + Ok(response_obj) + }); } if transport_bound.hasattr("handle_request")? { - let response = - transport_bound.call_method1("handle_request", (request_obj.clone_ref(py),))?; - if let Ok(mut py_response) = response.extract::>() { - if py_response.request.is_none() { - py_response.request = Some(request_obj.clone_ref(py)); - } - } - return immediate_awaitable(py, response.unbind()); + let transport_obj = transport_bound.unbind(); + let request_obj_for_hooks = request_obj.clone_ref(py); + let request_obj_for_transport = request_obj.clone_ref(py); + let request_obj_for_response = request_obj.clone_ref(py); + return pyo3_async_runtimes::tokio::future_into_py(py, async move { + run_async_request_hooks(request_hooks, request_obj_for_hooks).await?; + let response_obj = Python::attach(|py| { + let response = transport_obj.bind(py).call_method1( + "handle_request", + (request_obj_for_transport.clone_ref(py),), + )?; + if let Ok(mut py_response) = response.extract::>() + { + if py_response.request.is_none() { + py_response.request = Some(request_obj_for_response.clone_ref(py)); + } + } + Ok::, PyErr>(response.unbind()) + })?; + run_async_response_hooks( + response_hooks, + Python::attach(|py| response_obj.clone_ref(py)), + ) + .await?; + Ok(response_obj) + }); } } @@ -2815,27 +3080,16 @@ impl PyAsyncClient { this.cookie_jar.clone(), )? }; - let (method_str, url, mut headers, body) = { - let req_ref = request_obj.bind(py).borrow(); - let method_str = req_ref.method.clone(); - let url = req_ref.url.inner.to_string(); - let headers = req_ref.headers.inner.clone(); - let body = if req_ref.content.is_empty() && req_ref.py_stream.is_none() { - None - } else { - Some(req_ref.read(py)?) - }; - (method_str, url, headers, body) - }; - this.bind_default_cookies_for_url(&url); - if let Some(AuthKind::Basic(header_val)) = &effective_auth { - headers.retain(|(k, _)| k != "authorization"); - headers.push(("authorization".to_string(), header_val.clone())); - } + this.bind_default_cookies_for_url(&request_url); + drop(this); + let request_obj_for_hooks = request_obj.clone_ref(py); let request_obj_stream = request_obj.clone_ref(py); let request_obj_regular = request_obj.clone_ref(py); pyo3_async_runtimes::tokio::future_into_py(py, async move { + run_async_request_hooks(request_hooks, request_obj_for_hooks).await?; + let (method_str, url, headers, body) = + Python::attach(|py| extract_request_parts(py, &request_obj))?; let method = reqwest::Method::from_bytes(method_str.as_bytes()) .map_err(|_| PyValueError::new_err("Invalid method"))?; let mut builder = client.request(method, &url); @@ -2848,15 +3102,21 @@ impl PyAsyncClient { let start = Instant::now(); let response = builder.send().await.map_err(crate::map_reqwest_error)?; let elapsed = start.elapsed().as_millis(); - if stream { - Ok(PyResponse::from_async_stream( - response, - elapsed, - Some(request_obj_stream), - )) + let response_obj = if stream { + let response = + PyResponse::from_async_stream(response, elapsed, Some(request_obj_stream)); + Python::attach(|py| Py::new(py, response).map(|obj| obj.into_any()))? } else { - convert_async_response(response, elapsed, Some(request_obj_regular)).await - } + let response = + convert_async_response(response, elapsed, Some(request_obj_regular)).await?; + Python::attach(|py| Py::new(py, response).map(|obj| obj.into_any()))? + }; + run_async_response_hooks( + response_hooks, + Python::attach(|py| response_obj.clone_ref(py)), + ) + .await?; + Ok(response_obj) }) } diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 0316983..4c818a6 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -508,6 +508,58 @@ async def test_async_client_rejects_mount_with_non_string_key(): httprs.AsyncClient(mounts={1: object()}) +@pytest.mark.anyio +async def test_async_client_request_event_hooks_apply_to_network_requests(server): + events = [] + + async def on_request(request): + events.append(("request", request.method)) + request.set_header("x-event-hook", "async") + + async def on_response(response): + events.append(("response", response.status_code)) + + async with httprs.AsyncClient( + event_hooks={"request": [on_request], "response": [on_response]} + ) as client: + response = await client.get(server.url + "/echo_headers") + + assert response.status_code == 200 + assert response.json().get("x-event-hook") == "async" + assert events == [("request", "GET"), ("response", 200)] + + +@pytest.mark.anyio +async def test_async_send_event_hooks_apply_to_mounted_transport(): + events = [] + + class ApiTransport: + async def handle_async_request(self, request): + return httprs.Response( + 217, + text=request.headers.get("x-event-hook", ""), + request=request, + ) + + async def on_request(request): + request.set_header("x-event-hook", "mounted-async") + events.append("request") + + async def on_response(response): + events.append(("response", response.status_code, response.text)) + + async with httprs.AsyncClient( + event_hooks={"request": [on_request], "response": [on_response]}, + mounts={"https://example.com/api/": ApiTransport()}, + ) as client: + request = client.build_request("GET", "https://example.com/api/items") + response = await client.send(request) + + assert response.status_code == 217 + assert response.text == "mounted-async" + assert events == ["request", ("response", 217, "mounted-async")] + + @pytest.mark.anyio async def test_send_auth_argument_basic(server): async with httprs.AsyncClient() as client: diff --git a/tests/test_client.py b/tests/test_client.py index b3e0367..c8cbfcf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -291,6 +291,56 @@ def test_client_rejects_mount_with_non_string_key(): httprs.Client(mounts={1: object()}) +def test_client_request_event_hooks_apply_to_network_requests(server): + events = [] + + def on_request(request): + events.append(("request", request.method)) + request.set_header("x-event-hook", "sync") + + def on_response(response): + events.append(("response", response.status_code)) + + with httprs.Client( + event_hooks={"request": [on_request], "response": [on_response]} + ) as client: + response = client.get(server.url + "/echo_headers") + + assert response.status_code == 200 + assert response.json().get("x-event-hook") == "sync" + assert events == [("request", "GET"), ("response", 200)] + + +def test_client_send_event_hooks_apply_to_mounted_transport(): + events = [] + + class ApiTransport: + def handle_request(self, request): + return httprs.Response( + 299, + text=request.headers.get("x-event-hook", ""), + request=request, + ) + + def on_request(request): + request.set_header("x-event-hook", "mounted") + events.append("request") + + def on_response(response): + events.append(("response", response.status_code, response.text)) + + with httprs.Client( + event_hooks={"request": [on_request], "response": [on_response]}, + mounts={"https://example.com/api/": ApiTransport()}, + ) as client: + request = client.build_request("GET", "https://example.com/api/items") + response = client.send(request) + + assert response.status_code == 299 + assert response.text == "mounted" + assert events == ["request", ("response", 299, "mounted")] + + def test_send_auth_argument_basic(server): with httprs.Client() as client: request = client.build_request("GET", server.url + "/echo_headers")