-
-
Notifications
You must be signed in to change notification settings - Fork 26
Handle unread streaming request bodies during 402 retries #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,15 +23,25 @@ | |||||||||
| def _clone_request_with_headers( | ||||||||||
| original: httpx.Request, | ||||||||||
| extra_headers: dict[str, str], | ||||||||||
| *, | ||||||||||
| content: bytes | None = None, | ||||||||||
| ) -> httpx.Request: | ||||||||||
| """Clone *original* and merge *extra_headers* into the copy.""" | ||||||||||
| """Clone *original* and merge *extra_headers* into the copy. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| content: | ||||||||||
| Optional explicit body bytes to use for the cloned request. | ||||||||||
| When omitted, ``original.content`` is used and must already be | ||||||||||
| materialized by the caller. | ||||||||||
| """ | ||||||||||
| headers = dict(original.headers) | ||||||||||
| headers.update(extra_headers) | ||||||||||
| return httpx.Request( | ||||||||||
| method=original.method, | ||||||||||
| url=original.url, | ||||||||||
| headers=headers, | ||||||||||
| content=original.content, | ||||||||||
| content=original.content if content is None else content, | ||||||||||
| extensions=dict(original.extensions), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
@@ -78,7 +88,15 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: | |||||||||
| logger.exception("x402: payment signing failed") | ||||||||||
| return response | ||||||||||
|
|
||||||||||
| retry = _clone_request_with_headers(request, payment_headers) | ||||||||||
| try: | ||||||||||
| body = request.content | ||||||||||
| except httpx.RequestNotRead: | ||||||||||
| # Some transports/proxies can short-circuit with 402 before consuming | ||||||||||
| # the request body. Ensure we materialize it so the retry is replayable. | ||||||||||
| body = request.read() | ||||||||||
|
|
||||||||||
| retry = _clone_request_with_headers(request, payment_headers, content=body) | ||||||||||
| response.close() | ||||||||||
| return self._inner.handle_request(retry) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original 402 response is never closed before the retry.
Suggested change
|
||||||||||
|
|
||||||||||
| def close(self) -> None: | ||||||||||
|
|
@@ -128,7 +146,15 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | |||||||||
| logger.exception("x402: payment signing failed") | ||||||||||
| return response | ||||||||||
|
|
||||||||||
| retry = _clone_request_with_headers(request, payment_headers) | ||||||||||
| try: | ||||||||||
| body = request.content | ||||||||||
| except httpx.RequestNotRead: | ||||||||||
| # Some transports/proxies can short-circuit with 402 before consuming | ||||||||||
| # the request body. Ensure we materialize it so the retry is replayable. | ||||||||||
| body = await request.aread() | ||||||||||
|
|
||||||||||
| retry = _clone_request_with_headers(request, payment_headers, content=body) | ||||||||||
| await response.aclose() | ||||||||||
| return await self._inner.handle_async_request(retry) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same issue on the async path — the 402 response is never closed before the retry.
Suggested change
|
||||||||||
|
|
||||||||||
| async def aclose(self) -> None: | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,18 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import httpx | ||
| import pytest | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import AsyncIterator, Iterator | ||
|
|
||
| from x402_openai._transport import _clone_request_with_headers | ||
| from x402_openai._transport import ( | ||
| AsyncX402Transport, | ||
| X402Transport, | ||
| _clone_request_with_headers, | ||
| ) | ||
|
|
||
|
|
||
| def test_clone_request_merges_headers() -> None: | ||
|
|
@@ -37,3 +47,286 @@ def test_clone_request_preserves_extensions() -> None: | |
| cloned = _clone_request_with_headers(original, {"x-payment": "signed"}) | ||
|
|
||
| assert cloned.extensions == original.extensions | ||
|
|
||
|
|
||
| class _FakeX402ClientSync: | ||
| def handle_402_response( | ||
| self, | ||
| headers: dict[str, str], | ||
| body: bytes, | ||
| ) -> tuple[dict[str, str], dict[str, str]]: | ||
| assert headers["x-402"] == "required" | ||
| assert body == b"challenge" | ||
| return {"x-payment": "signed"}, {} | ||
|
|
||
|
|
||
| class _FailingX402ClientSync: | ||
| def handle_402_response( | ||
| self, | ||
| headers: dict[str, str], | ||
| body: bytes, | ||
| ) -> tuple[dict[str, str], dict[str, str]]: | ||
| raise RuntimeError("signing failed") | ||
|
|
||
|
|
||
| class _ShortCircuit402Transport(httpx.BaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
| self.retry_body = b"" | ||
| self.retry_headers: dict[str, str] = {} | ||
|
|
||
| def handle_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| if self.calls == 1: | ||
| # Intentionally do not consume request body. | ||
| return httpx.Response(402, headers={"x-402": "required"}, content=b"challenge") | ||
|
|
||
| self.retry_body = request.read() | ||
| self.retry_headers = dict(request.headers) | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _PassthroughTransport(httpx.BaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
|
|
||
| def handle_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _CloseTracking402Transport(httpx.BaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
| self.first_response_closed = False | ||
|
|
||
| def handle_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| if self.calls == 1: | ||
| response = httpx.Response(402, headers={"x-402": "required"}, content=b"challenge") | ||
| original_close = response.close | ||
|
|
||
| def tracked_close() -> None: | ||
| self.first_response_closed = True | ||
| original_close() | ||
|
|
||
| response.close = tracked_close # type: ignore[method-assign] | ||
| return response | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _CloseDelegatingTransport(httpx.BaseTransport): | ||
| def __init__(self) -> None: | ||
| self.closed = False | ||
|
|
||
| def handle_request(self, request: httpx.Request) -> httpx.Response: | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
| def close(self) -> None: | ||
| self.closed = True | ||
|
|
||
|
|
||
| def _iter_json() -> Iterator[bytes]: | ||
| yield b'{"prompt":' | ||
| yield b'"hi"}' | ||
|
|
||
|
|
||
| def test_sync_transport_retries_even_if_402_response_short_circuits_body() -> None: | ||
| inner = _ShortCircuit402Transport() | ||
| transport = X402Transport(_FakeX402ClientSync(), inner=inner) | ||
|
|
||
| request = httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json()) | ||
| response = transport.handle_request(request) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 2 | ||
| assert inner.retry_headers["x-payment"] == "signed" | ||
| assert inner.retry_body == b'{"prompt":"hi"}' | ||
|
|
||
|
|
||
| def test_sync_transport_passes_through_non_402_response() -> None: | ||
| inner = _PassthroughTransport() | ||
| transport = X402Transport(_FakeX402ClientSync(), inner=inner) | ||
|
|
||
| response = transport.handle_request(httpx.Request("GET", "https://example.com/v1/models")) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 1 | ||
|
|
||
|
|
||
| def test_sync_transport_returns_original_402_when_signing_fails() -> None: | ||
| inner = _ShortCircuit402Transport() | ||
| transport = X402Transport(_FailingX402ClientSync(), inner=inner) | ||
|
|
||
| response = transport.handle_request( | ||
| httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json()) | ||
| ) | ||
|
|
||
| assert response.status_code == 402 | ||
| assert inner.calls == 1 | ||
|
|
||
|
|
||
| def test_sync_transport_closes_original_402_before_retry() -> None: | ||
| inner = _CloseTracking402Transport() | ||
| transport = X402Transport(_FakeX402ClientSync(), inner=inner) | ||
|
|
||
| response = transport.handle_request( | ||
| httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json()) | ||
| ) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 2 | ||
| assert inner.first_response_closed is True | ||
|
|
||
|
|
||
| def test_sync_transport_close_delegates_to_inner_transport() -> None: | ||
| inner = _CloseDelegatingTransport() | ||
| transport = X402Transport(_FakeX402ClientSync(), inner=inner) | ||
|
|
||
| transport.close() | ||
|
|
||
| assert inner.closed is True | ||
|
|
||
|
|
||
| class _FakeX402ClientAsync: | ||
| async def handle_402_response( | ||
| self, | ||
| headers: dict[str, str], | ||
| body: bytes, | ||
| ) -> tuple[dict[str, str], dict[str, str]]: | ||
| assert headers["x-402"] == "required" | ||
| assert body == b"challenge" | ||
| return {"x-payment": "signed"}, {} | ||
|
|
||
|
|
||
| class _FailingX402ClientAsync: | ||
| async def handle_402_response( | ||
| self, | ||
| headers: dict[str, str], | ||
| body: bytes, | ||
| ) -> tuple[dict[str, str], dict[str, str]]: | ||
| raise RuntimeError("signing failed") | ||
|
|
||
|
|
||
| class _ShortCircuit402AsyncTransport(httpx.AsyncBaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
| self.retry_body = b"" | ||
| self.retry_headers: dict[str, str] = {} | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| if self.calls == 1: | ||
| # Intentionally do not consume request body. | ||
| return httpx.Response(402, headers={"x-402": "required"}, content=b"challenge") | ||
|
|
||
| self.retry_body = await request.aread() | ||
| self.retry_headers = dict(request.headers) | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _PassthroughAsyncTransport(httpx.AsyncBaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _CloseTracking402AsyncTransport(httpx.AsyncBaseTransport): | ||
| def __init__(self) -> None: | ||
| self.calls = 0 | ||
| self.first_response_closed = False | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| self.calls += 1 | ||
| if self.calls == 1: | ||
| response = httpx.Response(402, headers={"x-402": "required"}, content=b"challenge") | ||
| original_aclose = response.aclose | ||
|
|
||
| async def tracked_aclose() -> None: | ||
| self.first_response_closed = True | ||
| await original_aclose() | ||
|
|
||
| response.aclose = tracked_aclose # type: ignore[method-assign] | ||
| return response | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
|
|
||
| class _CloseDelegatingAsyncTransport(httpx.AsyncBaseTransport): | ||
| def __init__(self) -> None: | ||
| self.closed = False | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| return httpx.Response(200, content=b"ok") | ||
|
|
||
| async def aclose(self) -> None: | ||
| self.closed = True | ||
|
|
||
|
|
||
| async def _aiter_json() -> AsyncIterator[bytes]: | ||
| yield b'{"prompt":' | ||
| yield b'"hi"}' | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_transport_retries_even_if_402_response_short_circuits_body() -> None: | ||
| inner = _ShortCircuit402AsyncTransport() | ||
| transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner) | ||
|
|
||
| request = httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json()) | ||
| response = await transport.handle_async_request(request) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 2 | ||
| assert inner.retry_headers["x-payment"] == "signed" | ||
| assert inner.retry_body == b'{"prompt":"hi"}' | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good coverage for the streaming-body edge case. A few more scenarios would round out the test suite:
|
||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_transport_passes_through_non_402_response() -> None: | ||
| inner = _PassthroughAsyncTransport() | ||
| transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner) | ||
|
|
||
| response = await transport.handle_async_request(httpx.Request("GET", "https://example.com/v1/models")) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 1 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_transport_returns_original_402_when_signing_fails() -> None: | ||
| inner = _ShortCircuit402AsyncTransport() | ||
| transport = AsyncX402Transport(_FailingX402ClientAsync(), inner=inner) | ||
|
|
||
| response = await transport.handle_async_request( | ||
| httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json()) | ||
| ) | ||
|
|
||
| assert response.status_code == 402 | ||
| assert inner.calls == 1 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_transport_closes_original_402_before_retry() -> None: | ||
| inner = _CloseTracking402AsyncTransport() | ||
| transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner) | ||
|
|
||
| response = await transport.handle_async_request( | ||
| httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json()) | ||
| ) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert inner.calls == 2 | ||
| assert inner.first_response_closed is True | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_transport_aclose_delegates_to_inner_transport() -> None: | ||
| inner = _CloseDelegatingAsyncTransport() | ||
| transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner) | ||
|
|
||
| await transport.aclose() | ||
|
|
||
| assert inner.closed is True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: when `content is None`, this falls back to `original.content`, which will raise `RequestNotRead` if the body hasn't been consumed. Since every current call site passes `content=` explicitly this is fine in practice, but it's a footgun for future callers.
Consider either making `content` required, or documenting the precondition more prominently (e.g. "caller must ensure body is materialized when `content` is omitted").