Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions lib/req_llm/providers/google.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1134,12 +1134,14 @@ defmodule ReqLLM.Providers.Google do
raise ReqLLM.Error.Invalid.Parameter, parameter: "schema: #{message}"
end

defp normalize_embedding_response(%{"embedding" => %{"values" => values}})
defp normalize_embedding_response(%{"embedding" => %{"values" => values}} = body)
when is_list(values) do
%{"data" => [%{"index" => 0, "embedding" => values}]}
|> maybe_put_embedding_usage_metadata(body)
end

defp normalize_embedding_response(%{"embeddings" => embeddings}) when is_list(embeddings) do
defp normalize_embedding_response(%{"embeddings" => embeddings} = body)
when is_list(embeddings) do
data =
embeddings
|> Enum.with_index()
Expand All @@ -1153,10 +1155,21 @@ defmodule ReqLLM.Providers.Google do
end)

%{"data" => data}
|> maybe_put_embedding_usage_metadata(body)
end

# Fallback clause: pass through any response body that does not match the
# known single-embedding or batch-embedding shapes, unchanged, so callers
# can surface the raw payload (e.g. error bodies) instead of crashing here.
defp normalize_embedding_response(other), do: other

# Copies Google's "usageMetadata" (token counts) from the raw response body
# onto the normalized embedding payload, when present and well-formed.
# Returns `normalized` untouched when the body carries no usable metadata.
defp maybe_put_embedding_usage_metadata(normalized, body) do
  usage_metadata = Map.get(body, "usageMetadata")

  # Only attach the metadata when it is actually a map; a nil or malformed
  # value is silently ignored rather than polluting the normalized result.
  if is_map(usage_metadata) do
    Map.put(normalized, "usageMetadata", usage_metadata)
  else
    normalized
  end
end

@impl ReqLLM.Provider
def decode_response({req, resp}) do
case resp.status do
Expand Down
10 changes: 9 additions & 1 deletion test/providers/google_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,10 @@ defmodule ReqLLM.Providers.GoogleTest do
embedding_response = %{
"embedding" => %{
"values" => [0.1, -0.2, 0.3, 0.4, -0.5]
},
"usageMetadata" => %{
"promptTokenCount" => 2,
"totalTokenCount" => 2
}
}

Expand All @@ -764,7 +768,11 @@ defmodule ReqLLM.Providers.GoogleTest do
assert req == mock_req

assert resp.body == %{
"data" => [%{"index" => 0, "embedding" => [0.1, -0.2, 0.3, 0.4, -0.5]}]
"data" => [%{"index" => 0, "embedding" => [0.1, -0.2, 0.3, 0.4, -0.5]}],
"usageMetadata" => %{
"promptTokenCount" => 2,
"totalTokenCount" => 2
}
}
end

Expand Down
42 changes: 42 additions & 0 deletions test/req_llm/embedding_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,48 @@ defmodule ReqLLM.EmbeddingTest do
end
end

# Verifies that Google's embedding "usageMetadata" survives normalization:
# it should be exposed as `usage` on the embed result and emitted via the
# [:req_llm, :token_usage] telemetry event.
describe "embed/3 - Google usage metadata" do
setup do
# Stub a single-embedding Google response that includes usageMetadata,
# routed through Req's plug-based test transport (no real HTTP).
Req.Test.stub(__MODULE__.GoogleEmbedUsage, fn conn ->
Req.Test.json(conn, %{
"embedding" => %{
"values" => [0.1, -0.2, 0.3]
},
"usageMetadata" => %{
"promptTokenCount" => 2,
"totalTokenCount" => 2
}
})
end)

# NOTE(review): setup_telemetry() presumably attaches a handler that
# forwards telemetry events to this test process — confirm in test helpers.
setup_telemetry()
end

test "emits telemetry and returns usage for embeddings" do
{:ok, %{embedding: embedding, usage: usage}} =
Embedding.embed(
"google:gemini-embedding-001",
"Hello world",
api_key: "test-key",
return_usage: true,
req_http_options: [plug: {Req.Test, __MODULE__.GoogleEmbedUsage}]
)

# Embedding values and usage counts must match the stubbed response.
assert embedding == [0.1, -0.2, 0.3]
assert usage.input == 2
assert usage.total_tokens == 2

# The telemetry event should carry the same token counts and identify
# the Google model; the pattern match also binds `metadata` for the
# follow-up assertions below.
assert_receive {:telemetry_event, [:req_llm, :token_usage], measurements,
%{model: %LLMDB.Model{provider: :google, id: "gemini-embedding-001"}} =
metadata}

assert measurements.tokens.input == 2
assert measurements.tokens.total_tokens == 2
assert metadata.model.provider == :google
assert metadata.model.id == "gemini-embedding-001"
end
end

describe "embed_many/3 - basic functionality" do
test "validates model before attempting embedding" do
case Embedding.validate_model("openai:text-embedding-3-small") do
Expand Down
Loading