Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions lib/req_llm/streaming/finch_client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ defmodule ReqLLM.Streaming.FinchClient do
requests with proper authentication, headers, and request body formatting.
"""

alias ReqLLM.Streaming.Fixtures
alias ReqLLM.Streaming.{Fixtures, Retry}
alias ReqLLM.Streaming.Fixtures.HTTPContext
alias ReqLLM.StreamServer

Expand Down Expand Up @@ -193,13 +193,18 @@ defmodule ReqLLM.Streaming.FinchClient do
receive_timeout = Keyword.get(opts, :receive_timeout, default_timeout)

try do
case Finch.stream(finch_request, finch_name, :ok, finch_stream_callback,
receive_timeout: receive_timeout
case Retry.stream(
finch_request,
finch_name,
:ok,
finch_stream_callback,
receive_timeout: receive_timeout,
max_retries: Keyword.get(opts, :max_retries, 3)
) do
{:ok, _} ->
:ok

{:error, reason, _partial_acc} ->
{:error, reason, _callback_acc} ->
Logger.error("Finch streaming failed", reason: reason)
safe_http_event(stream_server_pid, {:error, reason})
{:error, reason}
Expand Down
126 changes: 126 additions & 0 deletions lib/req_llm/streaming/retry.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
defmodule ReqLLM.Streaming.Retry do
@moduledoc """
Retry wrapper for Finch streaming requests.

Streaming retries are intentionally conservative: only transient transport
failures that happen before any response body data is emitted are retried.
This avoids duplicating partial model output when a stream has already begun.
"""

require Logger

@retryable_reasons [:closed, :timeout, :econnrefused]

@type callback_acc :: term()
@type callback :: (term(), callback_acc() -> callback_acc())
@type stream_fun ::
(Finch.Request.t(), atom(), term(), (term(), term() -> term()), keyword() ->
{:ok, term()} | {:error, term(), term()})

@spec stream(
Finch.Request.t(),
atom(),
callback_acc(),
callback(),
keyword(),
stream_fun()
) :: {:ok, callback_acc()} | {:error, term(), callback_acc()}
def stream(request, finch_name, acc, callback, opts, stream_fun \\ &Finch.stream/5) do
max_retries = Keyword.get(opts, :max_retries, 3)
stream_opts = Keyword.take(opts, [:receive_timeout])

do_stream(request, finch_name, acc, callback, stream_opts, stream_fun, max_retries, 0)
end

defp do_stream(
request,
finch_name,
acc,
callback,
stream_opts,
stream_fun,
max_retries,
attempt
) do
initial_acc = %{callback_acc: acc, data_received?: false}
wrapped_callback = fn event, wrapped_acc -> apply_callback(event, wrapped_acc, callback) end

case stream_fun.(request, finch_name, initial_acc, wrapped_callback, stream_opts) do
{:ok, %{callback_acc: callback_acc}} ->
{:ok, callback_acc}

{:error, reason, %{data_received?: false, callback_acc: callback_acc}}
when attempt < max_retries ->
maybe_retry(
request,
finch_name,
acc,
callback,
stream_opts,
stream_fun,
max_retries,
attempt,
callback_acc,
reason
)

{:error, reason, %{callback_acc: callback_acc}} ->
{:error, reason, callback_acc}
end
end

defp maybe_retry(
request,
finch_name,
acc,
callback,
stream_opts,
stream_fun,
max_retries,
attempt,
callback_acc,
reason
) do
if retryable_reason?(reason) do
log_retry(reason, attempt + 1, max_retries)

do_stream(
request,
finch_name,
acc,
callback,
stream_opts,
stream_fun,
max_retries,
attempt + 1
)
else
{:error, reason, callback_acc}
end
end

defp apply_callback({:data, _} = event, %{callback_acc: callback_acc} = wrapped_acc, callback) do
%{wrapped_acc | callback_acc: callback.(event, callback_acc), data_received?: true}
end

defp apply_callback(event, %{callback_acc: callback_acc} = wrapped_acc, callback) do
%{wrapped_acc | callback_acc: callback.(event, callback_acc)}
end

defp retryable_reason?(%Mint.TransportError{reason: reason}) when reason in @retryable_reasons,
do: true

defp retryable_reason?(%Req.TransportError{reason: reason}) when reason in @retryable_reasons,
do: true

defp retryable_reason?(_reason), do: false

defp log_retry(reason, attempt, max_retries) do
Logger.warning(
"Retrying streaming request after transient transport error",
reason: inspect(reason),
attempt: attempt,
max_retries: max_retries
)
end
end
94 changes: 94 additions & 0 deletions test/req_llm/streaming/retry_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
defmodule ReqLLM.Streaming.RetryTest do
use ExUnit.Case, async: true

alias ReqLLM.Streaming.Retry

test "retries transient transport errors before any data is received" do
{:ok, counter} = Agent.start_link(fn -> 0 end)

stream_fun = fn _request, _finch_name, acc, callback, _opts ->
attempt = Agent.get_and_update(counter, fn current -> {current + 1, current + 1} end)
acc = callback.({:status, 200}, acc)
acc = callback.({:headers, [{"content-type", "text/event-stream"}]}, acc)

case attempt do
1 ->
{:error, %Mint.TransportError{reason: :closed}, acc}

2 ->
acc = callback.({:data, "hello"}, acc)
acc = callback.(:done, acc)
{:ok, acc}
end
end

callback = fn event, acc -> [event | acc] end

assert {:ok, events} =
Retry.stream(
Finch.build(:post, "https://example.com/stream"),
ReqLLM.Finch,
[],
callback,
[max_retries: 1, receive_timeout: 1_000],
stream_fun
)

assert Agent.get(counter, & &1) == 2

assert Enum.reverse(events) == [
{:status, 200},
{:headers, [{"content-type", "text/event-stream"}]},
{:data, "hello"},
:done
]
end

test "does not retry transient transport errors after data has been received" do
{:ok, counter} = Agent.start_link(fn -> 0 end)

stream_fun = fn _request, _finch_name, acc, callback, _opts ->
Agent.update(counter, &(&1 + 1))
acc = callback.({:data, "partial"}, acc)
{:error, %Mint.TransportError{reason: :timeout}, acc}
end

callback = fn event, acc -> [event | acc] end

assert {:error, %Mint.TransportError{reason: :timeout}, events} =
Retry.stream(
Finch.build(:post, "https://example.com/stream"),
ReqLLM.Finch,
[],
callback,
[max_retries: 3, receive_timeout: 1_000],
stream_fun
)

assert Agent.get(counter, & &1) == 1
assert Enum.reverse(events) == [{:data, "partial"}]
end

test "does not retry non-retryable transport errors" do
{:ok, counter} = Agent.start_link(fn -> 0 end)

stream_fun = fn _request, _finch_name, acc, _callback, _opts ->
Agent.update(counter, &(&1 + 1))
{:error, %Mint.TransportError{reason: :protocol_not_negotiated}, acc}
end

callback = fn event, acc -> [event | acc] end

assert {:error, %Mint.TransportError{reason: :protocol_not_negotiated}, []} =
Retry.stream(
Finch.build(:post, "https://example.com/stream"),
ReqLLM.Finch,
[],
callback,
[max_retries: 3, receive_timeout: 1_000],
stream_fun
)

assert Agent.get(counter, & &1) == 1
end
end
Loading