From bec4c63e8fafcedf2632829aaa927a8b33b39ca3 Mon Sep 17 00:00:00 2001
From: Mike Hostetler <84222+mikehostetler@users.noreply.github.com>
Date: Sun, 22 Mar 2026 19:50:58 -0500
Subject: [PATCH] fix: retry transient streaming transport failures

---
 lib/req_llm/streaming/finch_client.ex |  13 ++-
 lib/req_llm/streaming/retry.ex        | 156 ++++++++++++++++++++++++++++++++++++++++++++++++++
 test/req_llm/streaming/retry_test.exs | 118 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 283 insertions(+), 4 deletions(-)
 create mode 100644 lib/req_llm/streaming/retry.ex
 create mode 100644 test/req_llm/streaming/retry_test.exs

diff --git a/lib/req_llm/streaming/finch_client.ex b/lib/req_llm/streaming/finch_client.ex
index 2c869333c..4bc0dc81e 100644
--- a/lib/req_llm/streaming/finch_client.ex
+++ b/lib/req_llm/streaming/finch_client.ex
@@ -26,7 +26,7 @@ defmodule ReqLLM.Streaming.FinchClient do
   requests with proper authentication, headers, and request body formatting.
   """
 
-  alias ReqLLM.Streaming.Fixtures
+  alias ReqLLM.Streaming.{Fixtures, Retry}
   alias ReqLLM.Streaming.Fixtures.HTTPContext
   alias ReqLLM.StreamServer
 
@@ -193,13 +193,18 @@ defmodule ReqLLM.Streaming.FinchClient do
     receive_timeout = Keyword.get(opts, :receive_timeout, default_timeout)
 
     try do
-      case Finch.stream(finch_request, finch_name, :ok, finch_stream_callback,
-             receive_timeout: receive_timeout
+      case Retry.stream(
+             finch_request,
+             finch_name,
+             :ok,
+             finch_stream_callback,
+             receive_timeout: receive_timeout,
+             max_retries: Keyword.get(opts, :max_retries, 3)
            ) do
         {:ok, _} ->
           :ok
 
-        {:error, reason, _partial_acc} ->
+        {:error, reason, _callback_acc} ->
           Logger.error("Finch streaming failed", reason: reason)
           safe_http_event(stream_server_pid, {:error, reason})
           {:error, reason}
diff --git a/lib/req_llm/streaming/retry.ex b/lib/req_llm/streaming/retry.ex
new file mode 100644
index 000000000..3ce2cb98d
--- /dev/null
+++ b/lib/req_llm/streaming/retry.ex
@@ -0,0 +1,156 @@
+defmodule ReqLLM.Streaming.Retry do
+  @moduledoc """
+  Retry wrapper for Finch streaming requests.
+
+  Streaming retries are intentionally conservative: only transient transport
+  failures that happen before any response body data is emitted are retried.
+  This avoids duplicating partial model output when a stream has already begun.
+  """
+
+  require Logger
+
+  @default_max_retries 3
+  @retryable_reasons [:closed, :timeout, :econnrefused]
+
+  @type callback_acc :: term()
+  @type callback :: (term(), callback_acc() -> callback_acc())
+  @type stream_fun ::
+          (Finch.Request.t(), atom(), term(), (term(), term() -> term()), keyword() ->
+             {:ok, term()} | {:error, term(), term()})
+
+  @doc """
+  Runs `stream_fun` for `request`, retrying transient transport failures.
+
+  An attempt is retried only when it fails before any `{:data, _}` event reached
+  `callback` and the error is a `Mint.TransportError` or `Req.TransportError`
+  whose reason is `:closed`, `:timeout` or `:econnrefused`. Every attempt starts
+  from the original `acc`, so `callback` can see non-data events once per attempt.
+
+  ## Options
+
+    * `:max_retries` - retries allowed after the first attempt (default `3`).
+      Negative integers disable retries; non-integer values use the default.
+    * `:receive_timeout` - passed through to `stream_fun`.
+
+  Returns `{:ok, acc}` or `{:error, reason, acc}`, where `acc` is the callback
+  accumulator of the last attempt.
+  """
+  @spec stream(
+          Finch.Request.t(),
+          atom(),
+          callback_acc(),
+          callback(),
+          keyword(),
+          stream_fun()
+        ) :: {:ok, callback_acc()} | {:error, term(), callback_acc()}
+  def stream(request, finch_name, acc, callback, opts, stream_fun \\ &Finch.stream/5) do
+    max_retries = normalize_max_retries(Keyword.get(opts, :max_retries, @default_max_retries))
+    stream_opts = Keyword.take(opts, [:receive_timeout])
+
+    do_stream(request, finch_name, acc, callback, stream_opts, stream_fun, max_retries, 0)
+  end
+
+  # Integers always compare lower than atoms in Erlang term order, so an unchecked
+  # `nil` would keep `attempt < max_retries` true forever and never stop retrying.
+  defp normalize_max_retries(value) when is_integer(value), do: max(value, 0)
+  defp normalize_max_retries(_value), do: @default_max_retries
+
+  # Runs one attempt with the caller's accumulator wrapped so that the result
+  # also reports whether any body data was delivered during that attempt.
+  defp do_stream(
+         request,
+         finch_name,
+         acc,
+         callback,
+         stream_opts,
+         stream_fun,
+         max_retries,
+         attempt
+       ) do
+    initial_acc = %{callback_acc: acc, data_received?: false}
+    wrapped_callback = fn event, wrapped_acc -> apply_callback(event, wrapped_acc, callback) end
+
+    case stream_fun.(request, finch_name, initial_acc, wrapped_callback, stream_opts) do
+      {:ok, %{callback_acc: callback_acc}} ->
+        {:ok, callback_acc}
+
+      {:error, reason, %{data_received?: false, callback_acc: callback_acc}}
+      when attempt < max_retries ->
+        maybe_retry(
+          request,
+          finch_name,
+          acc,
+          callback,
+          stream_opts,
+          stream_fun,
+          max_retries,
+          attempt,
+          callback_acc,
+          reason
+        )
+
+      {:error, reason, %{callback_acc: callback_acc}} ->
+        {:error, reason, callback_acc}
+    end
+  end
+
+  # Starts the next attempt from the original `acc` when the error is transient;
+  # otherwise returns the error with the accumulator of the failed attempt.
+  defp maybe_retry(
+         request,
+         finch_name,
+         acc,
+         callback,
+         stream_opts,
+         stream_fun,
+         max_retries,
+         attempt,
+         callback_acc,
+         reason
+       ) do
+    if retryable_reason?(reason) do
+      log_retry(reason, attempt + 1, max_retries)
+
+      do_stream(
+        request,
+        finch_name,
+        acc,
+        callback,
+        stream_opts,
+        stream_fun,
+        max_retries,
+        attempt + 1
+      )
+    else
+      {:error, reason, callback_acc}
+    end
+  end
+
+  # Forwards every event to the caller's callback; a `{:data, _}` event also
+  # marks the attempt as having emitted body data, which rules out a retry.
+  defp apply_callback({:data, _} = event, %{callback_acc: callback_acc} = wrapped_acc, callback) do
+    %{wrapped_acc | callback_acc: callback.(event, callback_acc), data_received?: true}
+  end
+
+  defp apply_callback(event, %{callback_acc: callback_acc} = wrapped_acc, callback) do
+    %{wrapped_acc | callback_acc: callback.(event, callback_acc)}
+  end
+
+  # Only transport errors with a reason in @retryable_reasons count as transient.
+  defp retryable_reason?(%Mint.TransportError{reason: reason}) when reason in @retryable_reasons,
+    do: true
+
+  defp retryable_reason?(%Req.TransportError{reason: reason}) when reason in @retryable_reasons,
+    do: true
+
+  defp retryable_reason?(_reason), do: false
+
+  defp log_retry(reason, attempt, max_retries) do
+    Logger.warning(
+      "Retrying streaming request after transient transport error",
+      reason: inspect(reason),
+      attempt: attempt,
+      max_retries: max_retries
+    )
+  end
+end
diff --git a/test/req_llm/streaming/retry_test.exs b/test/req_llm/streaming/retry_test.exs
new file mode 100644
index 000000000..f88ff9756
--- /dev/null
+++ b/test/req_llm/streaming/retry_test.exs
@@ -0,0 +1,118 @@
+defmodule ReqLLM.Streaming.RetryTest do
+  use ExUnit.Case, async: true
+
+  alias ReqLLM.Streaming.Retry
+
+  test "retries transient transport errors before any data is received" do
+    {:ok, counter} = Agent.start_link(fn -> 0 end)
+
+    stream_fun = fn _request, _finch_name, acc, callback, _opts ->
+      attempt = Agent.get_and_update(counter, fn current -> {current + 1, current + 1} end)
+      acc = callback.({:status, 200}, acc)
+      acc = callback.({:headers, [{"content-type", "text/event-stream"}]}, acc)
+
+      case attempt do
+        1 ->
+          {:error, %Mint.TransportError{reason: :closed}, acc}
+
+        2 ->
+          acc = callback.({:data, "hello"}, acc)
+          acc = callback.(:done, acc)
+          {:ok, acc}
+      end
+    end
+
+    callback = fn event, acc -> [event | acc] end
+
+    assert {:ok, events} =
+             Retry.stream(
+               Finch.build(:post, "https://example.com/stream"),
+               ReqLLM.Finch,
+               [],
+               callback,
+               [max_retries: 1, receive_timeout: 1_000],
+               stream_fun
+             )
+
+    assert Agent.get(counter, & &1) == 2
+
+    assert Enum.reverse(events) == [
+             {:status, 200},
+             {:headers, [{"content-type", "text/event-stream"}]},
+             {:data, "hello"},
+             :done
+           ]
+  end
+
+  test "does not retry transient transport errors after data has been received" do
+    {:ok, counter} = Agent.start_link(fn -> 0 end)
+
+    stream_fun = fn _request, _finch_name, acc, callback, _opts ->
+      Agent.update(counter, &(&1 + 1))
+      acc = callback.({:data, "partial"}, acc)
+      {:error, %Mint.TransportError{reason: :timeout}, acc}
+    end
+
+    callback = fn event, acc -> [event | acc] end
+
+    assert {:error, %Mint.TransportError{reason: :timeout}, events} =
+             Retry.stream(
+               Finch.build(:post, "https://example.com/stream"),
+               ReqLLM.Finch,
+               [],
+               callback,
+               [max_retries: 3, receive_timeout: 1_000],
+               stream_fun
+             )
+
+    assert Agent.get(counter, & &1) == 1
+    assert Enum.reverse(events) == [{:data, "partial"}]
+  end
+
+  test "does not retry non-retryable transport errors" do
+    {:ok, counter} = Agent.start_link(fn -> 0 end)
+
+    stream_fun = fn _request, _finch_name, acc, _callback, _opts ->
+      Agent.update(counter, &(&1 + 1))
+      {:error, %Mint.TransportError{reason: :protocol_not_negotiated}, acc}
+    end
+
+    callback = fn event, acc -> [event | acc] end
+
+    assert {:error, %Mint.TransportError{reason: :protocol_not_negotiated}, []} =
+             Retry.stream(
+               Finch.build(:post, "https://example.com/stream"),
+               ReqLLM.Finch,
+               [],
+               callback,
+               [max_retries: 3, receive_timeout: 1_000],
+               stream_fun
+             )
+
+    assert Agent.get(counter, & &1) == 1
+  end
+
+  test "falls back to the default retry budget when :max_retries is not an integer" do
+    {:ok, counter} = Agent.start_link(fn -> 0 end)
+
+    stream_fun = fn _request, _finch_name, acc, _callback, _opts ->
+      Agent.update(counter, &(&1 + 1))
+      {:error, %Mint.TransportError{reason: :econnrefused}, acc}
+    end
+
+    callback = fn event, acc -> [event | acc] end
+
+    assert {:error, %Mint.TransportError{reason: :econnrefused}, []} =
+             Retry.stream(
+               Finch.build(:post, "https://example.com/stream"),
+               ReqLLM.Finch,
+               [],
+               callback,
+               [max_retries: nil, receive_timeout: 1_000],
+               stream_fun
+             )
+
+    # One initial attempt plus the default three retries.
+    assert Agent.get(counter, & &1) == 4
+  end
+end