Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/x402_openai/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,25 @@
def _clone_request_with_headers(
original: httpx.Request,
extra_headers: dict[str, str],
*,
content: bytes | None = None,
) -> httpx.Request:
"""Clone *original* and merge *extra_headers* into the copy."""
"""Clone *original* and merge *extra_headers* into the copy.

Parameters
----------
content:
Optional explicit body bytes to use for the cloned request.
When omitted, ``original.content`` is used and must already be
materialized by the caller.
"""
headers = dict(original.headers)
headers.update(extra_headers)
return httpx.Request(
method=original.method,
url=original.url,
headers=headers,
content=original.content,
content=original.content if content is None else content,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: when content is None, this falls back to original.content, which will raise RequestNotRead if the body hasn't been consumed. Since every current call site passes content= explicitly this is fine in practice, but it's a footgun for future callers.

Consider either making content required, or documenting the precondition more prominently (e.g. "caller must ensure body is materialized when content is omitted").

extensions=dict(original.extensions),
)

Expand Down Expand Up @@ -78,7 +88,15 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
logger.exception("x402: payment signing failed")
return response

retry = _clone_request_with_headers(request, payment_headers)
try:
body = request.content
except httpx.RequestNotRead:
# Some transports/proxies can short-circuit with 402 before consuming
# the request body. Ensure we materialize it so the retry is replayable.
body = request.read()

retry = _clone_request_with_headers(request, payment_headers, content=body)
response.close()
return self._inner.handle_request(retry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original 402 response is never closed before issuing the retry. Since response.read() was called above (L79), the body is consumed but the underlying connection is still held open. This will leak connections under sustained 402 traffic.

Suggested change
return self._inner.handle_request(retry)
retry = _clone_request_with_headers(request, payment_headers, content=body)
response.close()
return self._inner.handle_request(retry)


def close(self) -> None:
Expand Down Expand Up @@ -128,7 +146,15 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
logger.exception("x402: payment signing failed")
return response

retry = _clone_request_with_headers(request, payment_headers)
try:
body = request.content
except httpx.RequestNotRead:
# Some transports/proxies can short-circuit with 402 before consuming
# the request body. Ensure we materialize it so the retry is replayable.
body = await request.aread()

retry = _clone_request_with_headers(request, payment_headers, content=body)
await response.aclose()
return await self._inner.handle_async_request(retry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue on the async path — the 402 response is never closed before the retry.

Suggested change
return await self._inner.handle_async_request(retry)
retry = _clone_request_with_headers(request, payment_headers, content=body)
await response.aclose()
return await self._inner.handle_async_request(retry)


async def aclose(self) -> None:
Expand Down
295 changes: 294 additions & 1 deletion tests/test_transport.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import httpx
import pytest

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator

from x402_openai._transport import _clone_request_with_headers
from x402_openai._transport import (
AsyncX402Transport,
X402Transport,
_clone_request_with_headers,
)


def test_clone_request_merges_headers() -> None:
Expand Down Expand Up @@ -37,3 +47,286 @@ def test_clone_request_preserves_extensions() -> None:
cloned = _clone_request_with_headers(original, {"x-payment": "signed"})

assert cloned.extensions == original.extensions


class _FakeX402ClientSync:
def handle_402_response(
self,
headers: dict[str, str],
body: bytes,
) -> tuple[dict[str, str], dict[str, str]]:
assert headers["x-402"] == "required"
assert body == b"challenge"
return {"x-payment": "signed"}, {}


class _FailingX402ClientSync:
def handle_402_response(
self,
headers: dict[str, str],
body: bytes,
) -> tuple[dict[str, str], dict[str, str]]:
raise RuntimeError("signing failed")


class _ShortCircuit402Transport(httpx.BaseTransport):
    """Inner transport that answers 402 first without reading the request.

    The first call returns a 402 challenge and leaves the request body
    untouched. Every later call reads the request body, records it together
    with the request headers, and returns 200.
    """

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0
        # Body and headers captured from the most recent non-first request.
        self.retry_body = b""
        self.retry_headers: dict[str, str] = {}

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Return 402 on the first call; capture the request and return 200 after."""
        self.calls += 1
        if self.calls == 1:
            # Intentionally do not consume request body.
            return httpx.Response(402, headers={"x-402": "required"}, content=b"challenge")

        self.retry_body = request.read()
        self.retry_headers = dict(request.headers)
        return httpx.Response(200, content=b"ok")


class _PassthroughTransport(httpx.BaseTransport):
    """Inner transport that always succeeds and counts how often it is called."""

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Count the call and return a 200 response with body ``b"ok"``."""
        self.calls += 1
        return httpx.Response(200, content=b"ok")


class _CloseTracking402Transport(httpx.BaseTransport):
    """Inner transport that records whether its first (402) response is closed.

    The first call returns a 402 challenge whose ``close`` method is wrapped
    so that calling it sets ``first_response_closed``. Later calls return 200.
    """

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0
        # Set to True once close() is called on the first response.
        self.first_response_closed = False

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Return an instrumented 402 on the first call and a plain 200 afterwards."""
        self.calls += 1
        if self.calls == 1:
            response = httpx.Response(402, headers={"x-402": "required"}, content=b"challenge")
            # Keep the bound original so the wrapper still performs the real close.
            original_close = response.close

            def tracked_close() -> None:
                self.first_response_closed = True
                original_close()

            # Shadow the method on this instance only.
            response.close = tracked_close  # type: ignore[method-assign]
            return response
        return httpx.Response(200, content=b"ok")


class _CloseDelegatingTransport(httpx.BaseTransport):
    """Inner transport that records whether ``close()`` was called on it."""

    def __init__(self) -> None:
        # Set to True once close() has been invoked.
        self.closed = False

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Return a 200 response with body ``b"ok"``."""
        return httpx.Response(200, content=b"ok")

    def close(self) -> None:
        """Mark this transport as closed."""
        self.closed = True


def _iter_json() -> Iterator[bytes]:
yield b'{"prompt":'
yield b'"hi"}'


def test_sync_transport_retries_even_if_402_response_short_circuits_body() -> None:
    """A streamed body is replayed in full when the 402 arrives before it is read.

    The inner transport returns 402 without consuming the request body. The
    retry must still carry the payment header and the complete body.
    """
    inner = _ShortCircuit402Transport()
    transport = X402Transport(_FakeX402ClientSync(), inner=inner)

    request = httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json())
    response = transport.handle_request(request)

    assert response.status_code == 200
    assert inner.calls == 2
    assert inner.retry_headers["x-payment"] == "signed"
    assert inner.retry_body == b'{"prompt":"hi"}'


def test_sync_transport_passes_through_non_402_response() -> None:
    """A non-402 response is returned after exactly one inner call."""
    inner = _PassthroughTransport()
    transport = X402Transport(_FakeX402ClientSync(), inner=inner)

    response = transport.handle_request(httpx.Request("GET", "https://example.com/v1/models"))

    assert response.status_code == 200
    assert inner.calls == 1


def test_sync_transport_returns_original_402_when_signing_fails() -> None:
    """If the x402 client raises, the 402 is returned and no retry is issued."""
    inner = _ShortCircuit402Transport()
    transport = X402Transport(_FailingX402ClientSync(), inner=inner)

    response = transport.handle_request(
        httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json())
    )

    assert response.status_code == 402
    assert inner.calls == 1


def test_sync_transport_closes_original_402_before_retry() -> None:
    """The first 402 response is closed by the time the retried call returns.

    Only the fact that ``close()`` was called is asserted; its ordering
    relative to the retry request is not checked here.
    """
    inner = _CloseTracking402Transport()
    transport = X402Transport(_FakeX402ClientSync(), inner=inner)

    response = transport.handle_request(
        httpx.Request("POST", "https://example.com/v1/chat", content=_iter_json())
    )

    assert response.status_code == 200
    assert inner.calls == 2
    assert inner.first_response_closed is True


def test_sync_transport_close_delegates_to_inner_transport() -> None:
    """Closing the wrapper transport also closes the inner transport."""
    inner = _CloseDelegatingTransport()
    transport = X402Transport(_FakeX402ClientSync(), inner=inner)

    transport.close()

    assert inner.closed is True


class _FakeX402ClientAsync:
async def handle_402_response(
self,
headers: dict[str, str],
body: bytes,
) -> tuple[dict[str, str], dict[str, str]]:
assert headers["x-402"] == "required"
assert body == b"challenge"
return {"x-payment": "signed"}, {}


class _FailingX402ClientAsync:
async def handle_402_response(
self,
headers: dict[str, str],
body: bytes,
) -> tuple[dict[str, str], dict[str, str]]:
raise RuntimeError("signing failed")


class _ShortCircuit402AsyncTransport(httpx.AsyncBaseTransport):
    """Async inner transport that answers 402 first without reading the request.

    The first call returns a 402 challenge and leaves the request body
    untouched. Every later call reads the request body, records it together
    with the request headers, and returns 200.
    """

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0
        # Body and headers captured from the most recent non-first request.
        self.retry_body = b""
        self.retry_headers: dict[str, str] = {}

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Return 402 on the first call; capture the request and return 200 after."""
        self.calls += 1
        if self.calls == 1:
            # Intentionally do not consume request body.
            return httpx.Response(402, headers={"x-402": "required"}, content=b"challenge")

        self.retry_body = await request.aread()
        self.retry_headers = dict(request.headers)
        return httpx.Response(200, content=b"ok")


class _PassthroughAsyncTransport(httpx.AsyncBaseTransport):
    """Async inner transport that always succeeds and counts its calls."""

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Count the call and return a 200 response with body ``b"ok"``."""
        self.calls += 1
        return httpx.Response(200, content=b"ok")


class _CloseTracking402AsyncTransport(httpx.AsyncBaseTransport):
    """Async inner transport that records whether its first (402) response is closed.

    The first call returns a 402 challenge whose ``aclose`` method is wrapped
    so that awaiting it sets ``first_response_closed``. Later calls return 200.
    """

    def __init__(self) -> None:
        # Number of requests received so far.
        self.calls = 0
        # Set to True once aclose() is awaited on the first response.
        self.first_response_closed = False

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Return an instrumented 402 on the first call and a plain 200 afterwards."""
        self.calls += 1
        if self.calls == 1:
            response = httpx.Response(402, headers={"x-402": "required"}, content=b"challenge")
            # Keep the bound original so the wrapper still performs the real close.
            original_aclose = response.aclose

            async def tracked_aclose() -> None:
                self.first_response_closed = True
                await original_aclose()

            # Shadow the method on this instance only.
            response.aclose = tracked_aclose  # type: ignore[method-assign]
            return response
        return httpx.Response(200, content=b"ok")


class _CloseDelegatingAsyncTransport(httpx.AsyncBaseTransport):
    """Async inner transport that records whether ``aclose()`` was awaited on it."""

    def __init__(self) -> None:
        # Set to True once aclose() has been awaited.
        self.closed = False

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Return a 200 response with body ``b"ok"``."""
        return httpx.Response(200, content=b"ok")

    async def aclose(self) -> None:
        """Mark this transport as closed."""
        self.closed = True


async def _aiter_json() -> AsyncIterator[bytes]:
yield b'{"prompt":'
yield b'"hi"}'


@pytest.mark.asyncio
async def test_async_transport_retries_even_if_402_response_short_circuits_body() -> None:
    """A streamed body is replayed in full when the 402 arrives before it is read.

    The inner transport returns 402 without consuming the request body. The
    retry must still carry the payment header and the complete body.
    """
    inner = _ShortCircuit402AsyncTransport()
    transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner)

    request = httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json())
    response = await transport.handle_async_request(request)

    assert response.status_code == 200
    assert inner.calls == 2
    assert inner.retry_headers["x-payment"] == "signed"
    assert inner.retry_body == b'{"prompt":"hi"}'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good coverage for the streaming-body edge case. A few more scenarios would round out the test suite:

  • Happy path: non-402 response passes through untouched (ensures the transport doesn't accidentally mutate normal traffic).
  • Exception fallback: handle_402_response raises → original 402 is returned as-is.
  • close()/aclose() delegation: verifies the lifecycle methods forward to the inner transport.



@pytest.mark.asyncio
async def test_async_transport_passes_through_non_402_response() -> None:
    """A non-402 response is returned after exactly one inner call."""
    inner = _PassthroughAsyncTransport()
    transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner)

    response = await transport.handle_async_request(
        httpx.Request("GET", "https://example.com/v1/models")
    )

    assert response.status_code == 200
    assert inner.calls == 1


@pytest.mark.asyncio
async def test_async_transport_returns_original_402_when_signing_fails() -> None:
    """If the x402 client raises, the 402 is returned and no retry is issued."""
    inner = _ShortCircuit402AsyncTransport()
    transport = AsyncX402Transport(_FailingX402ClientAsync(), inner=inner)

    response = await transport.handle_async_request(
        httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json())
    )

    assert response.status_code == 402
    assert inner.calls == 1


@pytest.mark.asyncio
async def test_async_transport_closes_original_402_before_retry() -> None:
    """The first 402 response is closed by the time the retried call returns.

    Only the fact that ``aclose()`` was awaited is asserted; its ordering
    relative to the retry request is not checked here.
    """
    inner = _CloseTracking402AsyncTransport()
    transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner)

    response = await transport.handle_async_request(
        httpx.Request("POST", "https://example.com/v1/chat", content=_aiter_json())
    )

    assert response.status_code == 200
    assert inner.calls == 2
    assert inner.first_response_closed is True


@pytest.mark.asyncio
async def test_async_transport_aclose_delegates_to_inner_transport() -> None:
    """Closing the async wrapper transport also closes the inner transport."""
    inner = _CloseDelegatingAsyncTransport()
    transport = AsyncX402Transport(_FakeX402ClientAsync(), inner=inner)

    await transport.aclose()

    assert inner.closed is True