diff --git a/contrib/sarvam-ai/CAPABILITIES.md b/contrib/sarvam-ai/CAPABILITIES.md new file mode 100644 index 000000000..19e765c84 --- /dev/null +++ b/contrib/sarvam-ai/CAPABILITIES.md @@ -0,0 +1,221 @@ +# Sarvam AI - ADK Integration Capabilities + +## Overview + +The Sarvam AI module provides a comprehensive, production-grade integration of Sarvam AI services into the Google Agent Development Kit (ADK) for Java. It spans five service domains -- Chat, Speech-to-Text, Text-to-Speech, Vision, and Live Connections -- covering both REST and WebSocket protocols with full observability, resilience, and multi-turn agentic support. + +**Module path:** `contrib/sarvam-ai` +**Package:** `com.google.adk.models.sarvamai` +**Branch:** `sarvam-ai` + +--- + +## 1. Chat Completions (LLM) + +**Class:** `SarvamAi` extends `BaseLlm` +**Endpoint:** `POST /v1/chat/completions` (OpenAI-compatible) + +| Capability | Details | +|---|---| +| Blocking (non-streaming) | Full request/response cycle via `generateContent(request, false)` | +| SSE Streaming | Real-time token-by-token delivery via `generateContent(request, true)` with backpressure (RxJava `Flowable`) | +| Function / Tool Calling | ADK `FunctionDeclaration` serialized to OpenAI `tools` JSON with `tool_choice: auto` | +| Multi-turn Tool History | Prior `tool_calls` correctly formatted as assistant messages with `tool_call_id`, `function.name`, `function.arguments`; tool responses sent as `role: tool` | +| Streaming Function Calls | Chunked `name` and `arguments` accumulated across SSE deltas, emitted as final `FunctionCall` Part | +| Token Usage Tracking | `prompt_tokens`, `completion_tokens`, `total_tokens` extracted for both blocking and streaming modes. 
Streaming uses `stream_options: {"include_usage": true}` | +| System Instructions | ADK `GenerateContentConfig.systemInstruction` mapped to OpenAI `system` role message | +| Temperature Control | Forwarded from `GenerateContentConfig.temperature` (default 0.7) | +| Max Output Tokens | `GenerateContentConfig.maxOutputTokens` forwarded as `max_tokens` | +| Top-P Sampling | Configurable via `SarvamAiConfig.topP()` | +| Frequency / Presence Penalty | Configurable via `SarvamAiConfig` builder | +| Reasoning Effort | Sarvam-specific `reasoning_effort` parameter (low / medium / high) | +| Wiki Grounding | Sarvam-specific `wiki_grounding` toggle for factual grounding | +| Role Translation | ADK `model` -> OpenAI `assistant`, `user` -> `user`, `functionResponse` -> `tool` | +| Schema Normalization | Type strings lowercased, nested `items.properties` recursively normalized for OpenAI schema compatibility | +| Graceful Degradation | Empty choices return empty text response instead of crashing | + +### Dual Implementation + +| Implementation | Location | Use Case | +|---|---|---| +| `SarvamBaseLM` | `core/src/main/java/.../models/SarvamBaseLM.java` | Lightweight, env-var driven. Used by `AgentModelConfig` and `LlmRegistry` for `Sarvam\|model` config strings | +| `SarvamAi` | `contrib/sarvam-ai/src/.../SarvamAi.java` | Full-featured, Builder-pattern, OkHttp-based. Supports all chat parameters plus subservice access | + +--- + +## 2. 
Speech-to-Text (STT) + +**Class:** `SarvamSttService` implements `TranscriptionService` +**Model:** `saaras:v3` + +| Capability | Details | +|---|---| +| REST Synchronous | `transcribe(byte[] audioData, TranscriptionConfig)` via `POST /speech-to-text` with multipart/form-data | +| REST Async | `transcribeAsync()` executes on RxJava IO scheduler | +| WebSocket Streaming | Real-time streaming via `wss://api.sarvam.ai/speech-to-text/streaming` with VAD (Voice Activity Detection) signals | +| Transcription Modes | `transcribe`, `translate`, `verbatim`, `translit`, `codemix` | +| Language Detection | Auto-detection supported; explicit BCP-47 codes (e.g., `hi-IN`, `en-IN`) also accepted | +| VAD Signals | `speech_start` and `speech_end` events for voice activity boundaries | +| ADK TranscriptionService | Full implementation of ADK's `TranscriptionService` interface including `isAvailable()`, `getServiceType()`, `getHealth()` | + +--- + +## 3. Text-to-Speech (TTS) + +**Class:** `SarvamTtsService` +**Model:** `bulbul:v3` + +| Capability | Details | +|---|---| +| REST Synchronous | `synthesize(text, languageCode)` returns decoded WAV audio bytes | +| REST Async | `synthesizeAsync()` on IO scheduler | +| WebSocket Streaming | `synthesizeStream()` via `wss://api.sarvam.ai/text-to-speech/streaming` for low-latency progressive audio chunk delivery | +| 30+ Speaker Voices | Configurable via `SarvamAiConfig.ttsSpeaker()` (default: `shubh`) | +| Pace Control | Adjustable speech pace (0.5x to 2.0x) | +| Sample Rate | Configurable output sample rate | +| Base64 Decoding | Audio chunks automatically decoded from base64 to raw bytes | +| WebSocket Lifecycle | Config frame -> text frame -> flush frame -> audio chunks -> final event -> close | + +--- + +## 4. 
Vision / Document Intelligence + +**Class:** `SarvamVisionService` +**Model:** Sarvam Vision 3B VLM + +| Capability | Details | +|---|---| +| Multi-Language OCR | 23 languages (22 Indian + English) | +| Input Formats | PDF, PNG, JPG, ZIP | +| Output Formats | HTML or Markdown | +| Async Job Pipeline | `createJob` -> `uploadDocument` (presigned URL) -> `startJob` -> `getJobStatus` (poll) -> `downloadResults` | +| Convenience Method | `processDocument(filePath, languageCode, outputFormat)` runs the full pipeline with adaptive exponential backoff polling | +| Polling Backoff | Starts at 2s, doubles up to 10s cap, max 60 polls (worst case about 10 min: 2s + 4s + 8s + 57 x 10s) | + +--- + +## 5. Live Bidirectional Connection + +**Class:** `SarvamAiLlmConnection` implements `BaseLlmConnection` + +| Capability | Details | +|---|---| +| Multi-Turn Context | Maintains conversation history across turns, accumulates full model responses | +| sendHistory | Replace full conversation context | +| sendContent | Append a single turn and trigger streaming response | +| receive | Returns `Flowable` via `PublishSubject` for reactive consumers | +| Thread Safety | History list synchronized for concurrent access | +| Realtime Guard | `sendRealtime(Blob)` throws `UnsupportedOperationException` with guidance to use STT/TTS services | + +--- + +## 6.
Resilience & Configuration + +### Retry with Exponential Backoff + +**Class:** `SarvamRetryInterceptor` (OkHttp `Interceptor`) + +| Parameter | Value | +|---|---| +| Retryable codes | 429 (rate limit) and any 5xx server error (including 503) | +| Base delay | 500ms | +| Max delay | 30s | +| Strategy | Exponential backoff with 20% jitter | +| Default max retries | 3 | + +### Immutable Configuration + +**Class:** `SarvamAiConfig` (Builder pattern) + +| Parameter | Default | +|---|---| +| Chat endpoint | `https://api.sarvam.ai/v1/chat/completions` | +| STT endpoint | `https://api.sarvam.ai/speech-to-text` | +| STT WebSocket | `wss://api.sarvam.ai/speech-to-text/streaming` | +| TTS endpoint | `https://api.sarvam.ai/text-to-speech` | +| TTS WebSocket | `wss://api.sarvam.ai/text-to-speech/streaming` | +| Vision endpoint | `https://api.sarvam.ai/document-intelligence` | +| Connect timeout | 30s | +| Read timeout | 120s | +| Max retries | 3 | +| API key resolution | Explicit value > `SARVAM_API_KEY` env var | + +### Structured Error Handling + +**Class:** `SarvamAiException` extends `RuntimeException` + +| Field | Purpose | +|---|---| +| `statusCode` | HTTP status code from API (0 when the error did not originate from an HTTP response) | +| `errorCode` | Sarvam-specific error code | +| `requestId` | Sarvam request ID for support tracing | +| `isRetryable()` | Programmatic check: true for 429 and for any status >= 500 (503 included) | + +--- + +## 7. Authentication + +| Method | Header | Used By | +|---|---|---| +| API Subscription Key | `api-subscription-key: <API_KEY>` | `SarvamAi`, STT, TTS, Vision (contrib module) | +| Bearer Token | `Authorization: Bearer <API_KEY>` | `SarvamBaseLM` (core module, OpenAI-compatible) | +| Key Resolution | `SARVAM_API_KEY` env var or explicit via Builder | Both | +| Fail-Fast Validation | Warning logged at construction if key is missing | `SarvamBaseLM` | + +--- + +## 8.
Test Coverage + +| Test Class | Tests | Scope | +|---|---|---| +| `SarvamBaseLMTest` | 10 | Response parsing (text, null, tool calls), construction, connection type | +| `SarvamAiTest` | - | Chat completion blocking and streaming | +| `SarvamAiConfigTest` | - | Config builder validation, defaults, env var resolution | +| `ChatRequestTest` | - | Request serialization from LlmRequest | +| `SarvamSttServiceTest` | - | STT REST and WebSocket transcription | +| `SarvamTtsServiceTest` | - | TTS REST and WebSocket synthesis | +| `SarvamRetryInterceptorTest` | - | Retry logic, delay calculation, jitter | +| `SarvamIntegrationTest` (rae) | 20 | End-to-end config wiring across properties, YAML, LlmRegistry | + +--- + +## 9. RAE Integration (Consumer Project) + +| Integration Point | Mechanism | File | +|---|---|---| +| Code-based agents | `AgentModelConfig` recognizes `Sarvam\|` prefix, instantiates `SarvamBaseLM` | `AgentModelConfig.java` | +| YAML-based agents | `LlmRegistry.registerLlm("Sarvam\\|.*", ...)` factory | `ApplicationRegistry.java` | +| Model metadata | `sarvam:` provider in `models.yaml` with feature declarations | `models.yaml` | +| Config format | `Sarvam\|sarvam-m` -- single string works across both paths | `agent-models.properties` + `*.yaml` | +| Global coverage | 43 code-based + 28 YAML agent configs switched to Sarvam | All agent config files | + +--- + +## Architecture Summary + +``` +contrib/sarvam-ai/ + src/main/java/com/google/adk/models/sarvamai/ + SarvamAi.java # BaseLlm (chat, Builder pattern, OkHttp) + SarvamAiConfig.java # Immutable config for all services + SarvamAiException.java # Structured error with status/code/requestId + SarvamAiLlmConnection.java # Live bidirectional multi-turn connection + SarvamRetryInterceptor.java # Exponential backoff with jitter + chat/ + ChatRequest.java # OpenAI-compatible request model + ChatResponse.java # Response deserialization + ChatChoice.java # Choice wrapper + ChatMessage.java # Message model + 
ChatUsage.java # Token usage tracking + stt/ + SarvamSttService.java # REST + WebSocket STT (TranscriptionService) + tts/ + SarvamTtsService.java # REST + WebSocket TTS + TtsRequest.java # TTS request model + TtsResponse.java # TTS response model + vision/ + SarvamVisionService.java # Async job pipeline for document OCR + +core/src/main/java/com/google/adk/models/ + SarvamBaseLM.java # Lightweight BaseLlm for agent config integration +``` diff --git a/contrib/sarvam-ai/pom.xml b/contrib/sarvam-ai/pom.xml new file mode 100644 index 000000000..1b23411d3 --- /dev/null +++ b/contrib/sarvam-ai/pom.xml @@ -0,0 +1,133 @@ + + + + 4.0.0 + + + com.google.adk + google-adk-parent + 0.5.1-SNAPSHOT + ../../pom.xml + + + google-adk-sarvam-ai + Agent Development Kit - Sarvam AI + Sarvam AI integration for the Agent Development Kit. + + + + + com.google.adk + google-adk + ${project.version} + + + com.google.adk + google-adk-dev + ${project.version} + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + + + com.google.guava + guava + + + com.google.errorprone + error_prone_annotations + + + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + com.google.truth + truth + test + + + org.assertj + assertj-core + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + + + + + + maven-surefire-plugin + 3.5.2 + + + me.fabriciorby + maven-surefire-junit5-tree-reporter + 0.1.0 + + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + + + + plain + + + **/*Test.java + + + ${project.basedir}/src/test/java + + + + + diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java new file mode 100644 index 
000000000..4ced7f6c9 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -0,0 +1,292 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.models.sarvamai.chat.ChatRequest; +import com.google.adk.models.sarvamai.chat.ChatResponse; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import java.io.BufferedReader; +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam AI LLM integration for the Agent Development Kit. + * + *

Provides chat completion (blocking and streaming) via the Sarvam {@code sarvam-m} model using + * the OpenAI-compatible {@code /v1/chat/completions} endpoint. Authentication uses the {@code + * api-subscription-key} header per Sarvam API specification. + * + *

Follows the same architectural patterns as {@link com.google.adk.models.Gemini}, including + * Builder construction, immutable configuration, and RxJava-based streaming. + * + *

Usage: + * + *

{@code
+ * SarvamAi sarvam = SarvamAi.builder()
+ *     .modelName("sarvam-m")
+ *     .config(SarvamAiConfig.builder()
+ *         .apiKey("your-key")
+ *         .temperature(0.7)
+ *         .build())
+ *     .build();
+ * }
+ * + * @author Sandeep Belgavi + */ +public class SarvamAi extends BaseLlm { + + private static final Logger logger = LoggerFactory.getLogger(SarvamAi.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + SarvamAi(String modelName, SarvamAiConfig config, OkHttpClient httpClient) { + super(modelName); + this.config = Objects.requireNonNull(config, "config must not be null"); + this.httpClient = Objects.requireNonNull(httpClient, "httpClient must not be null"); + this.objectMapper = new ObjectMapper(); + } + + public static Builder builder() { + return new Builder(); + } + + /** Returns the active configuration. */ + public SarvamAiConfig config() { + return config; + } + + /** Returns the shared OkHttpClient for subservices (STT, TTS, Vision). */ + OkHttpClient httpClient() { + return httpClient; + } + + /** Returns the shared ObjectMapper. 
*/ + ObjectMapper objectMapper() { + return objectMapper; + } + + @Override + public Flowable generateContent(LlmRequest llmRequest, boolean stream) { + if (stream) { + return streamContent(llmRequest); + } + + return Flowable.fromCallable( + () -> { + ChatRequest chatRequest = ChatRequest.fromLlmRequest(model(), llmRequest, config, false); + String body = objectMapper.writeValueAsString(chatRequest); + logger.debug("Sending chat completion request to {}", config.chatEndpoint()); + logger.trace("Request body: {}", body); + + Request request = buildHttpRequest(config.chatEndpoint(), body); + + try (Response response = httpClient.newCall(request).execute()) { + handleErrorResponse(response); + String responseBody = response.body().string(); + logger.trace("Response body: {}", responseBody); + ChatResponse chatResponse = objectMapper.readValue(responseBody, ChatResponse.class); + return toLlmResponse(chatResponse); + } + }); + } + + private Flowable streamContent(LlmRequest llmRequest) { + return Flowable.create( + emitter -> { + try { + ChatRequest chatRequest = ChatRequest.fromLlmRequest(model(), llmRequest, config, true); + String body = objectMapper.writeValueAsString(chatRequest); + logger.debug("Sending streaming chat request to {}", config.chatEndpoint()); + + Request request = buildHttpRequest(config.chatEndpoint(), body); + + try (Response response = httpClient.newCall(request).execute()) { + handleErrorResponse(response); + + if (response.body() == null) { + emitter.onError(new SarvamAiException("Response body is null")); + return; + } + + try (BufferedReader reader = new BufferedReader(response.body().charStream())) { + String line; + while ((line = reader.readLine()) != null) { + if (emitter.isCancelled()) { + break; + } + if (!line.startsWith("data: ")) { + continue; + } + String data = line.substring(6).trim(); + if ("[DONE]".equals(data)) { + break; + } + try { + JsonNode chunk = objectMapper.readTree(data); + JsonNode choices = chunk.path("choices"); + 
if (choices.isArray() && !choices.isEmpty()) { + JsonNode delta = choices.get(0).path("delta"); + // has() is also true for an explicit JSON null, and asText() on that null node + // yields the literal string "null"; hasNonNull() keeps such deltas from being + // emitted as model text. + if (delta.hasNonNull("content")) { + String textChunk = delta.get("content").asText(); + Content content = + Content.builder().role("model").parts(Part.fromText(textChunk)).build(); + emitter.onNext( + LlmResponse.builder().content(content).partial(true).build()); + } + } + } catch (Exception parseError) { + // Best-effort streaming: a malformed chunk is dropped rather than failing the stream. + logger.trace("Skipping unparseable SSE line: {}", data); + } + } + } + emitter.onComplete(); + } + } catch (Exception e) { + if (!emitter.isCancelled()) { + emitter.onError(e); + } + } + }, + BackpressureStrategy.BUFFER); + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + logger.debug("Establishing Sarvam AI live connection"); + return new SarvamAiLlmConnection(this, llmRequest); + } + + Request buildHttpRequest(String url, String jsonBody) { + return new Request.Builder() + .url(url) + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(jsonBody, JSON_MEDIA_TYPE)) + .build(); + } + + void handleErrorResponse(Response response) throws IOException { + if (response.isSuccessful()) { + return; + } + String errorBody = response.body() != null ?
response.body().string() : ""; + String errorCode = null; + String requestId = null; + String message = "Sarvam API error " + response.code(); + + try { + JsonNode errorJson = objectMapper.readTree(errorBody); + JsonNode error = errorJson.path("error"); + if (!error.isMissingNode()) { + message = error.path("message").asText(message); + errorCode = error.path("code").asText(null); + requestId = error.path("request_id").asText(null); + } + } catch (Exception ignored) { + // Use raw error body as message fallback + if (!errorBody.isEmpty()) { + message = message + ": " + errorBody; + } + } + + throw new SarvamAiException(message, response.code(), errorCode, requestId); + } + + private LlmResponse toLlmResponse(ChatResponse chatResponse) { + if (chatResponse.getChoices() == null || chatResponse.getChoices().isEmpty()) { + throw new SarvamAiException("Empty choices in response"); + } + var choice = chatResponse.getChoices().get(0); + var effectiveMsg = choice.effectiveMessage(); + if (effectiveMsg == null || effectiveMsg.getContent() == null) { + throw new SarvamAiException("No content in response choice"); + } + + Content content = + Content.builder().role("model").parts(Part.fromText(effectiveMsg.getContent())).build(); + return LlmResponse.builder().content(content).build(); + } + + /** Builder for {@link SarvamAi}. Mirrors the Gemini builder pattern. */ + public static final class Builder { + private String modelName; + private SarvamAiConfig config; + private OkHttpClient httpClient; + + private Builder() {} + + @CanIgnoreReturnValue + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + @CanIgnoreReturnValue + public Builder config(SarvamAiConfig config) { + this.config = config; + return this; + } + + /** + * Provides a custom OkHttpClient. If not set, a default client is created with retry + * interceptor and timeouts from the config. 
+ */ + @CanIgnoreReturnValue + public Builder httpClient(OkHttpClient httpClient) { + this.httpClient = httpClient; + return this; + } + + public SarvamAi build() { + Objects.requireNonNull(modelName, "modelName must be set"); + Objects.requireNonNull(config, "config must be set"); + + OkHttpClient client = this.httpClient; + if (client == null) { + client = + new OkHttpClient.Builder() + .connectTimeout(config.connectTimeout().toMillis(), TimeUnit.MILLISECONDS) + .readTimeout(config.readTimeout().toMillis(), TimeUnit.MILLISECONDS) + .addInterceptor(new SarvamRetryInterceptor(config.maxRetries())) + .build(); + } + + return new SarvamAi(modelName, config, client); + } + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java new file mode 100644 index 000000000..061bf4818 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java @@ -0,0 +1,398 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import java.time.Duration; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; + +/** + * Immutable configuration for Sarvam AI services. + * + *

Supports all Sarvam API parameters including chat completion, STT, TTS, and Vision. Uses the + * Builder pattern for safe, incremental construction with sensible defaults. + * + *

API key resolution order: explicit value > {@code SARVAM_API_KEY} environment variable. + * + * @author Sandeep Belgavi + */ +public final class SarvamAiConfig { + + public static final String DEFAULT_CHAT_ENDPOINT = "https://api.sarvam.ai/v1/chat/completions"; + public static final String DEFAULT_STT_ENDPOINT = "https://api.sarvam.ai/speech-to-text"; + public static final String DEFAULT_STT_WS_ENDPOINT = + "wss://api.sarvam.ai/speech-to-text/streaming"; + public static final String DEFAULT_TTS_ENDPOINT = "https://api.sarvam.ai/text-to-speech"; + public static final String DEFAULT_TTS_WS_ENDPOINT = + "wss://api.sarvam.ai/text-to-speech/streaming"; + public static final String DEFAULT_VISION_ENDPOINT = + "https://api.sarvam.ai/document-intelligence"; + public static final Duration DEFAULT_CONNECT_TIMEOUT = Duration.ofSeconds(30); + public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(120); + public static final int DEFAULT_MAX_RETRIES = 3; + + private final String apiKey; + private final String chatEndpoint; + private final String sttEndpoint; + private final String sttWsEndpoint; + private final String ttsEndpoint; + private final String ttsWsEndpoint; + private final String visionEndpoint; + private final Duration connectTimeout; + private final Duration readTimeout; + private final int maxRetries; + + // Chat-specific parameters + private final OptionalDouble temperature; + private final OptionalDouble topP; + private final OptionalInt maxTokens; + private final Optional reasoningEffort; + private final Optional wikiGrounding; + private final OptionalDouble frequencyPenalty; + private final OptionalDouble presencePenalty; + + // TTS-specific parameters + private final Optional ttsSpeaker; + private final Optional ttsModel; + private final OptionalDouble ttsPace; + private final OptionalInt ttsSampleRate; + + // STT-specific parameters + private final Optional sttModel; + private final Optional sttMode; + private final Optional sttLanguageCode; + 
+ private SarvamAiConfig(Builder builder) { + String resolvedKey = builder.apiKey; + if (Strings.isNullOrEmpty(resolvedKey)) { + resolvedKey = System.getenv("SARVAM_API_KEY"); + } + Preconditions.checkArgument( + !Strings.isNullOrEmpty(resolvedKey), + "Sarvam API key is required. Set via builder or SARVAM_API_KEY environment variable."); + this.apiKey = resolvedKey; + + this.chatEndpoint = Objects.requireNonNullElse(builder.chatEndpoint, DEFAULT_CHAT_ENDPOINT); + this.sttEndpoint = Objects.requireNonNullElse(builder.sttEndpoint, DEFAULT_STT_ENDPOINT); + this.sttWsEndpoint = Objects.requireNonNullElse(builder.sttWsEndpoint, DEFAULT_STT_WS_ENDPOINT); + this.ttsEndpoint = Objects.requireNonNullElse(builder.ttsEndpoint, DEFAULT_TTS_ENDPOINT); + this.ttsWsEndpoint = Objects.requireNonNullElse(builder.ttsWsEndpoint, DEFAULT_TTS_WS_ENDPOINT); + this.visionEndpoint = + Objects.requireNonNullElse(builder.visionEndpoint, DEFAULT_VISION_ENDPOINT); + this.connectTimeout = + Objects.requireNonNullElse(builder.connectTimeout, DEFAULT_CONNECT_TIMEOUT); + this.readTimeout = Objects.requireNonNullElse(builder.readTimeout, DEFAULT_READ_TIMEOUT); + this.maxRetries = builder.maxRetries; + this.temperature = builder.temperature; + this.topP = builder.topP; + this.maxTokens = builder.maxTokens; + this.reasoningEffort = Optional.ofNullable(builder.reasoningEffort); + this.wikiGrounding = Optional.ofNullable(builder.wikiGrounding); + this.frequencyPenalty = builder.frequencyPenalty; + this.presencePenalty = builder.presencePenalty; + this.ttsSpeaker = Optional.ofNullable(builder.ttsSpeaker); + this.ttsModel = Optional.ofNullable(builder.ttsModel); + this.ttsPace = builder.ttsPace; + this.ttsSampleRate = builder.ttsSampleRate; + this.sttModel = Optional.ofNullable(builder.sttModel); + this.sttMode = Optional.ofNullable(builder.sttMode); + this.sttLanguageCode = Optional.ofNullable(builder.sttLanguageCode); + } + + public static Builder builder() { + return new Builder(); + } + + public 
String apiKey() { + return apiKey; + } + + public String chatEndpoint() { + return chatEndpoint; + } + + public String sttEndpoint() { + return sttEndpoint; + } + + public String sttWsEndpoint() { + return sttWsEndpoint; + } + + public String ttsEndpoint() { + return ttsEndpoint; + } + + public String ttsWsEndpoint() { + return ttsWsEndpoint; + } + + public String visionEndpoint() { + return visionEndpoint; + } + + public Duration connectTimeout() { + return connectTimeout; + } + + public Duration readTimeout() { + return readTimeout; + } + + public int maxRetries() { + return maxRetries; + } + + public OptionalDouble temperature() { + return temperature; + } + + public OptionalDouble topP() { + return topP; + } + + public OptionalInt maxTokens() { + return maxTokens; + } + + public Optional reasoningEffort() { + return reasoningEffort; + } + + public Optional wikiGrounding() { + return wikiGrounding; + } + + public OptionalDouble frequencyPenalty() { + return frequencyPenalty; + } + + public OptionalDouble presencePenalty() { + return presencePenalty; + } + + public Optional ttsSpeaker() { + return ttsSpeaker; + } + + public Optional ttsModel() { + return ttsModel; + } + + public OptionalDouble ttsPace() { + return ttsPace; + } + + public OptionalInt ttsSampleRate() { + return ttsSampleRate; + } + + public Optional sttModel() { + return sttModel; + } + + public Optional sttMode() { + return sttMode; + } + + public Optional sttLanguageCode() { + return sttLanguageCode; + } + + /** Builder for {@link SarvamAiConfig}. 
*/ + public static final class Builder { + private String apiKey; + private String chatEndpoint; + private String sttEndpoint; + private String sttWsEndpoint; + private String ttsEndpoint; + private String ttsWsEndpoint; + private String visionEndpoint; + private Duration connectTimeout; + private Duration readTimeout; + private int maxRetries = DEFAULT_MAX_RETRIES; + private OptionalDouble temperature = OptionalDouble.empty(); + private OptionalDouble topP = OptionalDouble.empty(); + private OptionalInt maxTokens = OptionalInt.empty(); + private String reasoningEffort; + private Boolean wikiGrounding; + private OptionalDouble frequencyPenalty = OptionalDouble.empty(); + private OptionalDouble presencePenalty = OptionalDouble.empty(); + private String ttsSpeaker; + private String ttsModel; + private OptionalDouble ttsPace = OptionalDouble.empty(); + private OptionalInt ttsSampleRate = OptionalInt.empty(); + private String sttModel; + private String sttMode; + private String sttLanguageCode; + + private Builder() {} + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder chatEndpoint(String chatEndpoint) { + this.chatEndpoint = chatEndpoint; + return this; + } + + public Builder sttEndpoint(String sttEndpoint) { + this.sttEndpoint = sttEndpoint; + return this; + } + + public Builder sttWsEndpoint(String sttWsEndpoint) { + this.sttWsEndpoint = sttWsEndpoint; + return this; + } + + public Builder ttsEndpoint(String ttsEndpoint) { + this.ttsEndpoint = ttsEndpoint; + return this; + } + + public Builder ttsWsEndpoint(String ttsWsEndpoint) { + this.ttsWsEndpoint = ttsWsEndpoint; + return this; + } + + public Builder visionEndpoint(String visionEndpoint) { + this.visionEndpoint = visionEndpoint; + return this; + } + + public Builder connectTimeout(Duration connectTimeout) { + this.connectTimeout = connectTimeout; + return this; + } + + public Builder readTimeout(Duration readTimeout) { + this.readTimeout = readTimeout; + 
return this; + } + + public Builder maxRetries(int maxRetries) { + Preconditions.checkArgument(maxRetries >= 0, "maxRetries must be >= 0"); + this.maxRetries = maxRetries; + return this; + } + + public Builder temperature(double temperature) { + Preconditions.checkArgument( + temperature >= 0 && temperature <= 2, "temperature must be between 0 and 2"); + this.temperature = OptionalDouble.of(temperature); + return this; + } + + public Builder topP(double topP) { + Preconditions.checkArgument(topP >= 0 && topP <= 1, "topP must be between 0 and 1"); + this.topP = OptionalDouble.of(topP); + return this; + } + + public Builder maxTokens(int maxTokens) { + Preconditions.checkArgument(maxTokens > 0, "maxTokens must be > 0"); + this.maxTokens = OptionalInt.of(maxTokens); + return this; + } + + public Builder reasoningEffort(String reasoningEffort) { + Preconditions.checkArgument( + "low".equals(reasoningEffort) + || "medium".equals(reasoningEffort) + || "high".equals(reasoningEffort), + "reasoningEffort must be one of: low, medium, high"); + this.reasoningEffort = reasoningEffort; + return this; + } + + public Builder wikiGrounding(boolean wikiGrounding) { + this.wikiGrounding = wikiGrounding; + return this; + } + + public Builder frequencyPenalty(double frequencyPenalty) { + Preconditions.checkArgument( + frequencyPenalty >= -2 && frequencyPenalty <= 2, + "frequencyPenalty must be between -2 and 2"); + this.frequencyPenalty = OptionalDouble.of(frequencyPenalty); + return this; + } + + public Builder presencePenalty(double presencePenalty) { + Preconditions.checkArgument( + presencePenalty >= -2 && presencePenalty <= 2, + "presencePenalty must be between -2 and 2"); + this.presencePenalty = OptionalDouble.of(presencePenalty); + return this; + } + + public Builder ttsSpeaker(String ttsSpeaker) { + this.ttsSpeaker = ttsSpeaker; + return this; + } + + public Builder ttsModel(String ttsModel) { + this.ttsModel = ttsModel; + return this; + } + + public Builder ttsPace(double 
ttsPace) { + Preconditions.checkArgument( + ttsPace >= 0.5 && ttsPace <= 2.0, "ttsPace must be between 0.5 and 2.0"); + this.ttsPace = OptionalDouble.of(ttsPace); + return this; + } + + public Builder ttsSampleRate(int ttsSampleRate) { + this.ttsSampleRate = OptionalInt.of(ttsSampleRate); + return this; + } + + public Builder sttModel(String sttModel) { + this.sttModel = sttModel; + return this; + } + + public Builder sttMode(String sttMode) { + Preconditions.checkArgument( + "transcribe".equals(sttMode) + || "translate".equals(sttMode) + || "verbatim".equals(sttMode) + || "translit".equals(sttMode) + || "codemix".equals(sttMode), + "sttMode must be one of: transcribe, translate, verbatim, translit, codemix"); + this.sttMode = sttMode; + return this; + } + + public Builder sttLanguageCode(String sttLanguageCode) { + this.sttLanguageCode = sttLanguageCode; + return this; + } + + public SarvamAiConfig build() { + return new SarvamAiConfig(this); + } + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java new file mode 100644 index 000000000..7c52f76c5 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java @@ -0,0 +1,69 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Unchecked exception for Sarvam AI API errors.
 *
 * <p>When built from an API response it carries the HTTP status code plus an optional error code
 * and request id. When built from a message alone, or from a message and a cause, the status code
 * is {@code 0} and both optional values are absent.
 *
 * @author Sandeep Belgavi
 */
public class SarvamAiException extends RuntimeException {

  private final int statusCode;
  private final String errorCode;
  private final String requestId;

  /**
   * Creates an exception describing an API error response.
   *
   * @param message human-readable description
   * @param statusCode HTTP status code of the response
   * @param errorCode API error code, may be {@code null}
   * @param requestId API request id, may be {@code null}
   */
  public SarvamAiException(String message, int statusCode, String errorCode, String requestId) {
    super(message);
    this.statusCode = statusCode;
    this.errorCode = errorCode;
    this.requestId = requestId;
  }

  /**
   * Creates an exception that wraps an underlying failure; no status information is attached.
   *
   * @param message human-readable description
   * @param cause the underlying failure
   */
  public SarvamAiException(String message, Throwable cause) {
    super(message, cause);
    this.statusCode = 0;
    this.errorCode = null;
    this.requestId = null;
  }

  /**
   * Creates an exception with only a message; no status information is attached.
   *
   * @param message human-readable description
   */
  public SarvamAiException(String message) {
    this(message, 0, null, null);
  }

  /** Returns the HTTP status code, or {@code 0} when none was supplied. */
  public int statusCode() {
    return statusCode;
  }

  /** Returns the API error code, or empty when none was supplied. */
  public Optional<String> errorCode() {
    return Optional.ofNullable(errorCode);
  }

  /** Returns the API request id, or empty when none was supplied. */
  public Optional<String> requestId() {
    return Optional.ofNullable(requestId);
  }

  /**
   * Reports whether the status code indicates a failure worth retrying.
   *
   * @return {@code true} for status 429 and for any status of 500 or above (which includes 503)
   */
  public boolean isRetryable() {
    return statusCode == 429 || statusCode >= 500;
  }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.subjects.PublishSubject; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Live bidirectional connection to Sarvam AI, implementing multi-turn streaming conversations. + * + *

Maintains conversation history and streams responses token-by-token using SSE. Accumulates the + * full model response into history after each turn to support multi-turn context. + * + * @author Sandeep Belgavi + */ +final class SarvamAiLlmConnection implements BaseLlmConnection { + + private static final Logger logger = LoggerFactory.getLogger(SarvamAiLlmConnection.class); + + private final SarvamAi sarvamAi; + private final LlmRequest initialRequest; + private final List history; + private final PublishSubject responseSubject = PublishSubject.create(); + + SarvamAiLlmConnection(SarvamAi sarvamAi, LlmRequest llmRequest) { + this.sarvamAi = sarvamAi; + this.initialRequest = llmRequest; + this.history = new ArrayList<>(llmRequest.contents()); + } + + @Override + public Completable sendHistory(List newHistory) { + return Completable.fromAction( + () -> { + synchronized (history) { + history.clear(); + history.addAll(newHistory); + } + generateAndStream(); + }) + .subscribeOn(Schedulers.io()); + } + + @Override + public Completable sendContent(Content content) { + return Completable.fromAction( + () -> { + synchronized (history) { + history.add(content); + } + generateAndStream(); + }) + .subscribeOn(Schedulers.io()); + } + + @Override + public Completable sendRealtime(Blob blob) { + return Completable.error( + new UnsupportedOperationException( + "Realtime audio/video blobs are not supported on the chat connection. 
" + + "Use SarvamSttService for STT and SarvamTtsService for TTS.")); + } + + @Override + public Flowable receive() { + return responseSubject.toFlowable(BackpressureStrategy.BUFFER); + } + + @Override + public void close() { + responseSubject.onComplete(); + } + + @Override + public void close(Throwable throwable) { + responseSubject.onError(throwable); + } + + private void generateAndStream() { + List snapshot; + synchronized (history) { + snapshot = new ArrayList<>(history); + } + + LlmRequest.Builder turnBuilder = + LlmRequest.builder() + .contents(snapshot) + .appendTools(new ArrayList<>(initialRequest.tools().values())); + + initialRequest.config().ifPresent(turnBuilder::config); + turnBuilder.appendInstructions(initialRequest.getSystemInstructions()); + + LlmRequest turnRequest = turnBuilder.build(); + + StringBuilder fullText = new StringBuilder(); + + sarvamAi + .generateContent(turnRequest, true) + .subscribe( + response -> { + responseSubject.onNext(response); + response + .content() + .flatMap(Content::parts) + .ifPresent( + parts -> { + for (Part part : parts) { + part.text().ifPresent(fullText::append); + } + }); + }, + error -> { + logger.error("Error during Sarvam streaming turn", error); + responseSubject.onError(error); + }, + () -> { + if (fullText.length() > 0) { + Content responseContent = + Content.builder() + .role("model") + .parts(Part.fromText(fullText.toString())) + .build(); + synchronized (history) { + history.add(responseContent); + } + } + }); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java new file mode 100644 index 000000000..8f0d9bda5 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java @@ -0,0 +1,105 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import java.io.IOException; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * OkHttp interceptor that implements exponential backoff with jitter for retryable Sarvam API + * errors (429 rate limit, 5xx server errors). + * + * @author Sandeep Belgavi + */ +final class SarvamRetryInterceptor implements Interceptor { + + private static final Logger logger = LoggerFactory.getLogger(SarvamRetryInterceptor.class); + private static final long BASE_DELAY_MS = 500; + private static final long MAX_DELAY_MS = 30_000; + + private final int maxRetries; + + SarvamRetryInterceptor(int maxRetries) { + this.maxRetries = maxRetries; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + IOException lastException = null; + + for (int attempt = 0; attempt <= maxRetries; attempt++) { + try { + Response response = chain.proceed(request); + + if (response.isSuccessful() || !isRetryable(response.code()) || attempt == maxRetries) { + return response; + } + + response.close(); + long delay = calculateDelay(attempt); + logger.warn( + "Sarvam API returned {} for {}. 
Retrying in {}ms (attempt {}/{})", + response.code(), + request.url(), + delay, + attempt + 1, + maxRetries); + + sleep(delay); + } catch (IOException e) { + lastException = e; + if (attempt == maxRetries) { + break; + } + long delay = calculateDelay(attempt); + logger.warn( + "Sarvam API request failed: {}. Retrying in {}ms (attempt {}/{})", + e.getMessage(), + delay, + attempt + 1, + maxRetries); + sleep(delay); + } + } + + throw lastException != null ? lastException : new IOException("Request failed after retries"); + } + + private static boolean isRetryable(int statusCode) { + return statusCode == 429 || statusCode == 503 || statusCode >= 500; + } + + static long calculateDelay(int attempt) { + long delay = BASE_DELAY_MS * (1L << attempt); + delay = Math.min(delay, MAX_DELAY_MS); + long jitter = (long) (delay * 0.2 * Math.random()); + return delay + jitter; + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java new file mode 100644 index 000000000..0dd907812 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java @@ -0,0 +1,79 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A choice in the Sarvam AI chat completion response. Handles both non-streaming ({@code message}) + * and streaming ({@code delta}) response formats. + * + * @author Sandeep Belgavi + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ChatChoice { + + @JsonProperty("index") + private int index; + + @JsonProperty("message") + private ChatMessage message; + + @JsonProperty("delta") + private ChatMessage delta; + + @JsonProperty("finish_reason") + private String finishReason; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public ChatMessage getMessage() { + return message; + } + + public void setMessage(ChatMessage message) { + this.message = message; + } + + public ChatMessage getDelta() { + return delta; + } + + public void setDelta(ChatMessage delta) { + this.delta = delta; + } + + public String getFinishReason() { + return finishReason; + } + + public void setFinishReason(String finishReason) { + this.finishReason = finishReason; + } + + /** Returns the effective message content, preferring delta for streaming responses. */ + public ChatMessage effectiveMessage() { + return delta != null ? delta : message; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java new file mode 100644 index 000000000..a820ac47e --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java @@ -0,0 +1,71 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A message in the Sarvam AI chat completion API (request or response). + * + * @author Sandeep Belgavi + */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class ChatMessage { + + @JsonProperty("role") + private String role; + + @JsonProperty("content") + private String content; + + @JsonProperty("reasoning_content") + private String reasoningContent; + + public ChatMessage() {} + + public ChatMessage(String role, String content) { + this.role = role; + this.content = content; + } + + public String getRole() { + return role; + } + + public void setRole(String role) { + this.role = role; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public String getReasoningContent() { + return reasoningContent; + } + + public void setReasoningContent(String reasoningContent) { + this.reasoningContent = reasoningContent; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java new file mode 100644 index 000000000..3faefa2e9 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java @@ -0,0 +1,154 @@ +/* + * 
Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.List; + +/** + * Request body for the Sarvam AI chat completions endpoint. Constructed from the ADK {@link + * LlmRequest} and {@link SarvamAiConfig}. 
+ * + * @author Sandeep Belgavi + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class ChatRequest { + + @JsonProperty("model") + private String model; + + @JsonProperty("messages") + private List messages; + + @JsonProperty("stream") + private Boolean stream; + + @JsonProperty("temperature") + private Double temperature; + + @JsonProperty("top_p") + private Double topP; + + @JsonProperty("max_tokens") + private Integer maxTokens; + + @JsonProperty("reasoning_effort") + private String reasoningEffort; + + @JsonProperty("wiki_grounding") + private Boolean wikiGrounding; + + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + @JsonProperty("presence_penalty") + private Double presencePenalty; + + @JsonProperty("n") + private Integer n; + + @JsonProperty("seed") + private Integer seed; + + @JsonProperty("stop") + private Object stop; + + public ChatRequest() {} + + /** + * Converts an ADK {@link LlmRequest} into a Sarvam-native {@link ChatRequest}, applying config + * defaults and mapping ADK roles to OpenAI-compatible roles. + */ + public static ChatRequest fromLlmRequest( + String modelName, LlmRequest llmRequest, SarvamAiConfig config, boolean stream) { + ChatRequest request = new ChatRequest(); + request.model = modelName; + request.stream = stream ? 
true : null; + request.messages = new ArrayList<>(); + + for (String instruction : llmRequest.getSystemInstructions()) { + request.messages.add(new ChatMessage("system", instruction)); + } + + for (Content content : llmRequest.contents()) { + String role = content.role().orElse("user"); + if ("model".equals(role)) { + role = "assistant"; + } + StringBuilder textBuilder = new StringBuilder(); + content + .parts() + .ifPresent( + parts -> { + for (Part part : parts) { + part.text().ifPresent(textBuilder::append); + } + }); + if (textBuilder.length() > 0) { + request.messages.add(new ChatMessage(role, textBuilder.toString())); + } + } + + config.temperature().ifPresent(v -> request.temperature = v); + config.topP().ifPresent(v -> request.topP = v); + config.maxTokens().ifPresent(v -> request.maxTokens = v); + config.reasoningEffort().ifPresent(v -> request.reasoningEffort = v); + config.wikiGrounding().ifPresent(v -> request.wikiGrounding = v); + config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v); + config.presencePenalty().ifPresent(v -> request.presencePenalty = v); + + return request; + } + + public String getModel() { + return model; + } + + public List getMessages() { + return messages; + } + + public Boolean getStream() { + return stream; + } + + public Double getTemperature() { + return temperature; + } + + public Double getTopP() { + return topP; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public String getReasoningEffort() { + return reasoningEffort; + } + + public Boolean getWikiGrounding() { + return wikiGrounding; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java new file mode 100644 index 000000000..b3a215475 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025 Google LLC + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Response from the Sarvam AI chat completions endpoint. Supports both non-streaming and streaming + * (SSE chunk) formats. + * + * @author Sandeep Belgavi + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ChatResponse { + + @JsonProperty("id") + private String id; + + @JsonProperty("object") + private String object; + + @JsonProperty("created") + private long created; + + @JsonProperty("model") + private String model; + + @JsonProperty("choices") + private List choices; + + @JsonProperty("usage") + private ChatUsage usage; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getObject() { + return object; + } + + public void setObject(String object) { + this.object = object; + } + + public long getCreated() { + return created; + } + + public void setCreated(long created) { + this.created = created; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public List getChoices() { + return choices; + } + + public void setChoices(List choices) { + this.choices = choices; + } + + public ChatUsage getUsage() { + return usage; + } + + public void setUsage(ChatUsage 
usage) { + this.usage = usage; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java new file mode 100644 index 000000000..11812cf1b --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java @@ -0,0 +1,62 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Token usage metadata from Sarvam AI API response. 
+ * + * @author Sandeep Belgavi + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ChatUsage { + + @JsonProperty("prompt_tokens") + private int promptTokens; + + @JsonProperty("completion_tokens") + private int completionTokens; + + @JsonProperty("total_tokens") + private int totalTokens; + + public int getPromptTokens() { + return promptTokens; + } + + public void setPromptTokens(int promptTokens) { + this.promptTokens = promptTokens; + } + + public int getCompletionTokens() { + return completionTokens; + } + + public void setCompletionTokens(int completionTokens) { + this.completionTokens = completionTokens; + } + + public int getTotalTokens() { + return totalTokens; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java new file mode 100644 index 000000000..0398f5ef7 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java @@ -0,0 +1,273 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai.stt; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import com.google.adk.transcription.ServiceHealth; +import com.google.adk.transcription.ServiceType; +import com.google.adk.transcription.TranscriptionConfig; +import com.google.adk.transcription.TranscriptionEvent; +import com.google.adk.transcription.TranscriptionException; +import com.google.adk.transcription.TranscriptionResult; +import com.google.adk.transcription.TranscriptionService; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Base64; +import java.util.Objects; +import okhttp3.MediaType; +import okhttp3.MultipartBody; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam AI Speech-to-Text service implementing the ADK {@link TranscriptionService} interface. + * + *

<p>Supports three modes of operation:
+ *
+ * <ul>
+ *   <li>REST synchronous ({@link #transcribe}): Single-shot transcription via {@code POST
+ *       /speech-to-text} using model {@code saaras:v3}.
+ *   <li>REST async ({@link #transcribeAsync}): Same as above, executed on an IO scheduler.
+ *   <li>WebSocket streaming ({@link #transcribeStream}): Real-time streaming via WebSocket
+ *       with VAD support, delivering partial and final transcription events.
+ * </ul>
+ * + * @author Sandeep Belgavi + */ +public final class SarvamSttService implements TranscriptionService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamSttService.class); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamSttService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + @Override + public TranscriptionResult transcribe(byte[] audioData, TranscriptionConfig requestConfig) + throws TranscriptionException { + try { + String sttModel = config.sttModel().orElse("saaras:v3"); + String mode = config.sttMode().orElse("transcribe"); + String languageCode = config.sttLanguageCode().orElse(requestConfig.getLanguage()); + + RequestBody fileBody = RequestBody.create(audioData, MediaType.parse("audio/wav")); + + MultipartBody.Builder bodyBuilder = + new MultipartBody.Builder() + .setType(MultipartBody.FORM) + .addFormDataPart("file", "audio.wav", fileBody) + .addFormDataPart("model", sttModel) + .addFormDataPart("mode", mode); + + if (languageCode != null && !"auto".equals(languageCode)) { + bodyBuilder.addFormDataPart("language_code", languageCode); + } + + Request request = + new Request.Builder() + .url(config.sttEndpoint()) + .addHeader("api-subscription-key", config.apiKey()) + .post(bodyBuilder.build()) + .build(); + + logger.debug( + "Sending STT request to {} with model={}, mode={}", config.sttEndpoint(), sttModel, mode); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? 
response.body().string() : ""; + throw new TranscriptionException( + "STT request failed with status " + response.code() + ": " + errorBody); + } + + String responseBody = response.body().string(); + JsonNode root = objectMapper.readTree(responseBody); + String transcript = root.path("transcript").asText(""); + String detectedLang = root.path("language_code").asText(null); + + TranscriptionResult.Builder resultBuilder = + TranscriptionResult.builder().text(transcript).timestamp(System.currentTimeMillis()); + + if (detectedLang != null) { + resultBuilder.language(detectedLang); + } + + return resultBuilder.build(); + } + } catch (TranscriptionException e) { + throw e; + } catch (Exception e) { + throw new TranscriptionException("STT transcription failed", e); + } + } + + @Override + public Single transcribeAsync( + byte[] audioData, TranscriptionConfig requestConfig) { + return Single.fromCallable(() -> transcribe(audioData, requestConfig)) + .subscribeOn(Schedulers.io()); + } + + /** + * Streams audio data to Sarvam's WebSocket STT endpoint for real-time transcription. + * + *

Audio chunks are base64-encoded and sent as JSON frames. The server responds with transcript + * events including partial results and VAD signals (speech_start, speech_end). + */ + @Override + public Flowable transcribeStream( + Flowable audioStream, TranscriptionConfig requestConfig) { + + return Flowable.create( + emitter -> { + String sttModel = config.sttModel().orElse("saaras:v3"); + String mode = config.sttMode().orElse("transcribe"); + String languageCode = config.sttLanguageCode().orElse(requestConfig.getLanguage()); + + StringBuilder wsUrl = new StringBuilder(config.sttWsEndpoint()); + wsUrl.append("?model=").append(sttModel); + wsUrl.append("&mode=").append(mode); + if (languageCode != null && !"auto".equals(languageCode)) { + wsUrl.append("&language_code=").append(languageCode); + } + wsUrl.append("&high_vad_sensitivity=true"); + wsUrl.append("&vad_signals=true"); + + Request wsRequest = + new Request.Builder() + .url(wsUrl.toString()) + .addHeader("api-subscription-key", config.apiKey()) + .build(); + + logger.debug("Opening STT WebSocket to {}", wsUrl); + + WebSocket webSocket = + httpClient.newWebSocket( + wsRequest, + new WebSocketListener() { + @Override + public void onOpen(WebSocket ws, Response response) { + logger.debug("STT WebSocket connected"); + audioStream.subscribe( + chunk -> { + String base64Audio = Base64.getEncoder().encodeToString(chunk); + String frame = + String.format( + "{\"audio\":\"%s\",\"encoding\":\"audio/wav\",\"sample_rate\":16000}", + base64Audio); + ws.send(frame); + }, + error -> { + logger.error("Audio stream error", error); + ws.close(1000, "Audio stream error"); + }, + () -> { + logger.debug("Audio stream completed, closing WebSocket"); + ws.close(1000, "Stream complete"); + }); + } + + @Override + public void onMessage(WebSocket ws, String text) { + try { + JsonNode node = objectMapper.readTree(text); + String type = node.path("type").asText(""); + + switch (type) { + case "transcript": + case "translation": + String 
transcript = node.path("text").asText(""); + emitter.onNext( + TranscriptionEvent.builder() + .text(transcript) + .finished(true) + .timestamp(System.currentTimeMillis()) + .build()); + break; + case "speech_start": + logger.trace("VAD: speech started"); + break; + case "speech_end": + logger.trace("VAD: speech ended"); + break; + default: + logger.trace("Received STT WS message type: {}", type); + } + } catch (Exception e) { + logger.warn("Failed to parse STT WS message: {}", text, e); + } + } + + @Override + public void onClosing(WebSocket ws, int code, String reason) { + logger.debug("STT WebSocket closing: {} {}", code, reason); + ws.close(code, reason); + } + + @Override + public void onClosed(WebSocket ws, int code, String reason) { + logger.debug("STT WebSocket closed: {} {}", code, reason); + emitter.onComplete(); + } + + @Override + public void onFailure(WebSocket ws, Throwable t, Response response) { + logger.error("STT WebSocket failure", t); + if (!emitter.isCancelled()) { + emitter.onError( + new SarvamAiException("STT WebSocket connection failed", t)); + } + } + }); + + emitter.setCancellable(() -> webSocket.close(1000, "Cancelled")); + }, + BackpressureStrategy.BUFFER); + } + + @Override + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } + + @Override + public ServiceType getServiceType() { + return ServiceType.SARVAM; + } + + @Override + public ServiceHealth getHealth() { + return ServiceHealth.builder().available(isAvailable()).serviceType(ServiceType.SARVAM).build(); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java new file mode 100644 index 000000000..414a8b5b6 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java @@ -0,0 +1,240 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.tts; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Base64; +import java.util.Objects; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam AI Text-to-Speech service with both REST and WebSocket streaming support. + * + *

REST mode ({@link #synthesize}): Sends text and returns the complete audio as a byte array + * (decoded from base64). Uses the Bulbul v3 model with 30+ speaker voices. + * + *

WebSocket streaming mode ({@link #synthesizeStream}): Opens a persistent WebSocket connection + * for progressive audio chunk delivery with low latency. Audio chunks are emitted as they are + * synthesized, enabling real-time playback. + * + * @author Sandeep Belgavi + */ +public final class SarvamTtsService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamTtsService.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamTtsService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + /** + * Synthesizes speech from text synchronously via the REST endpoint. + * + * @param text the text to convert to speech (max 2500 chars for bulbul:v3) + * @param targetLanguageCode BCP-47 language code (e.g., "en-IN", "hi-IN") + * @return decoded audio bytes (WAV format by default) + */ + public byte[] synthesize(String text, String targetLanguageCode) { + Objects.requireNonNull(text, "text must not be null"); + Objects.requireNonNull(targetLanguageCode, "targetLanguageCode must not be null"); + + String model = config.ttsModel().orElse("bulbul:v3"); + String speaker = config.ttsSpeaker().orElse("shubh"); + Double pace = config.ttsPace().isPresent() ? config.ttsPace().getAsDouble() : null; + Integer sampleRate = + config.ttsSampleRate().isPresent() ? 
config.ttsSampleRate().getAsInt() : null; + + TtsRequest ttsRequest = + new TtsRequest(text, targetLanguageCode, model, speaker, pace, sampleRate); + + try { + String body = objectMapper.writeValueAsString(ttsRequest); + + Request request = + new Request.Builder() + .url(config.ttsEndpoint()) + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug( + "Sending TTS request to {} with model={}, speaker={}", + config.ttsEndpoint(), + model, + speaker); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : ""; + throw new SarvamAiException( + "TTS request failed: " + response.code() + " " + errorBody, + response.code(), + null, + null); + } + + TtsResponse ttsResponse = + objectMapper.readValue(response.body().string(), TtsResponse.class); + if (ttsResponse.getAudios() == null || ttsResponse.getAudios().isEmpty()) { + throw new SarvamAiException("TTS response contained no audio data"); + } + + String combinedBase64 = String.join("", ttsResponse.getAudios()); + return Base64.getDecoder().decode(combinedBase64); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("TTS synthesis failed", e); + } + } + + /** Async version of {@link #synthesize}. */ + public Single synthesizeAsync(String text, String targetLanguageCode) { + return Single.fromCallable(() -> synthesize(text, targetLanguageCode)) + .subscribeOn(Schedulers.io()); + } + + /** + * Streams TTS audio via WebSocket for low-latency, progressive playback. + * + *

Opens a WebSocket to Sarvam's streaming TTS endpoint, sends config + text, and emits decoded + * audio chunks as they arrive. Each chunk is a raw audio byte array ready for playback. + * + * @param text the text to synthesize + * @param targetLanguageCode BCP-47 language code + * @return a Flowable of audio byte[] chunks + */ + public Flowable synthesizeStream(String text, String targetLanguageCode) { + Objects.requireNonNull(text, "text must not be null"); + Objects.requireNonNull(targetLanguageCode, "targetLanguageCode must not be null"); + + return Flowable.create( + emitter -> { + String model = config.ttsModel().orElse("bulbul:v3"); + String speaker = config.ttsSpeaker().orElse("shubh"); + + String wsUrl = config.ttsWsEndpoint() + "?model=" + model; + + Request wsRequest = + new Request.Builder() + .url(wsUrl) + .addHeader("api-subscription-key", config.apiKey()) + .build(); + + logger.debug("Opening TTS WebSocket to {}", wsUrl); + + WebSocket webSocket = + httpClient.newWebSocket( + wsRequest, + new WebSocketListener() { + @Override + public void onOpen(WebSocket ws, Response response) { + logger.debug("TTS WebSocket connected"); + + String configMsg = + String.format( + "{\"type\":\"config\",\"data\":{\"speaker\":\"%s\"," + + "\"target_language_code\":\"%s\"}}", + speaker, targetLanguageCode); + ws.send(configMsg); + + String textMsg = + String.format( + "{\"type\":\"text\",\"data\":{\"text\":\"%s\"}}", + text.replace("\"", "\\\"")); + ws.send(textMsg); + + ws.send("{\"type\":\"flush\"}"); + } + + @Override + public void onMessage(WebSocket ws, String messageText) { + try { + JsonNode node = objectMapper.readTree(messageText); + String type = node.path("type").asText(""); + + if ("audio".equals(type)) { + String audioBase64 = node.path("data").path("audio").asText(""); + if (!audioBase64.isEmpty()) { + byte[] audioChunk = Base64.getDecoder().decode(audioBase64); + emitter.onNext(audioChunk); + } + } else if ("event".equals(type)) { + String eventType = 
node.path("data").path("event_type").asText(""); + if ("final".equals(eventType)) { + ws.close(1000, "Synthesis complete"); + } + } + } catch (Exception e) { + logger.warn("Failed to parse TTS WS message", e); + } + } + + @Override + public void onClosing(WebSocket ws, int code, String reason) { + ws.close(code, reason); + } + + @Override + public void onClosed(WebSocket ws, int code, String reason) { + logger.debug("TTS WebSocket closed: {} {}", code, reason); + emitter.onComplete(); + } + + @Override + public void onFailure(WebSocket ws, Throwable t, Response response) { + logger.error("TTS WebSocket failure", t); + if (!emitter.isCancelled()) { + emitter.onError( + new SarvamAiException("TTS WebSocket connection failed", t)); + } + } + }); + + emitter.setCancellable(() -> webSocket.close(1000, "Cancelled")); + }, + BackpressureStrategy.BUFFER); + } + + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java new file mode 100644 index 000000000..b387cec08 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package com.google.adk.models.sarvamai.tts;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * Request body for the Sarvam AI text-to-speech REST endpoint.
 *
 * <p>Fields are serialized under the snake_case names given by their {@code @JsonProperty}
 * annotations. Because of {@code @JsonInclude(NON_NULL)}, any field left {@code null} is omitted
 * from the JSON entirely.
 *
 * @author Sandeep Belgavi
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public final class TtsRequest {

  /** Text to synthesize; serialized as {@code text}. */
  @JsonProperty("text")
  private String text;

  /** Language of the output speech; serialized as {@code target_language_code}. */
  @JsonProperty("target_language_code")
  private String targetLanguageCode;

  /** Model identifier; serialized as {@code model}. */
  @JsonProperty("model")
  private String model;

  /** Voice name; serialized as {@code speaker}. */
  @JsonProperty("speaker")
  private String speaker;

  /** Optional pace value; omitted from the JSON when {@code null}. */
  @JsonProperty("pace")
  private Double pace;

  /** Optional sample rate; serialized as {@code speech_sample_rate}, omitted when {@code null}. */
  @JsonProperty("speech_sample_rate")
  private Integer speechSampleRate;

  /** Creates an empty request. NOTE(review): presumably kept for Jackson - confirm it is needed. */
  public TtsRequest() {}

  /**
   * Creates a fully populated request.
   *
   * @param text text to synthesize
   * @param targetLanguageCode language code of the output speech
   * @param model model identifier
   * @param speaker voice name
   * @param pace pace value, or {@code null} to leave it out of the request
   * @param speechSampleRate sample rate, or {@code null} to leave it out of the request
   */
  public TtsRequest(
      String text,
      String targetLanguageCode,
      String model,
      String speaker,
      Double pace,
      Integer speechSampleRate) {
    this.text = text;
    this.targetLanguageCode = targetLanguageCode;
    this.model = model;
    this.speaker = speaker;
    this.pace = pace;
    this.speechSampleRate = speechSampleRate;
  }

  /** Returns the text to synthesize. */
  public String getText() {
    return text;
  }

  /** Returns the target language code. */
  public String getTargetLanguageCode() {
    return targetLanguageCode;
  }

  /** Returns the model identifier. */
  public String getModel() {
    return model;
  }

  /** Returns the speaker (voice) name. */
  public String getSpeaker() {
    return speaker;
  }

  /** Returns the pace, or {@code null} when none was set. */
  public Double getPace() {
    return pace;
  }

  /** Returns the sample rate, or {@code null} when none was set. */
  public Integer getSpeechSampleRate() {
    return speechSampleRate;
  }
}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java new file mode 100644 index 000000000..3712bdad6 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java @@ -0,0 +1,53 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.tts; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Response from the Sarvam AI text-to-speech REST endpoint. + * + * @author Sandeep Belgavi + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class TtsResponse { + + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("audios") + private List audios; + + public String getRequestId() { + return requestId; + } + + public void setRequestId(String requestId) { + this.requestId = requestId; + } + + /** Returns base64-encoded audio strings. Each element corresponds to an input text segment. */ + public List getAudios() { + return audios; + } + + public void setAudios(List audios) { + this.audios = audios; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java new file mode 100644 index 000000000..420d491ce --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java @@ -0,0 +1,296 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.vision; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; +import java.util.Optional; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam Vision Document Intelligence service. + * + *

Powered by the Sarvam Vision 3B VLM for extracting structured text from documents across 23 + * languages (22 Indian + English). Supports PDF, PNG, JPG, and ZIP inputs with HTML or Markdown + * output. + * + *

The workflow follows Sarvam's async job pattern: + * + *

    + *
  1. {@link #createJob} - Initialize a document processing job + *
  2. {@link #uploadDocument} - Upload the document to the job's presigned URL + *
  3. {@link #startJob} - Begin processing + *
  4. {@link #getJobStatus} - Poll for completion + *
  5. {@link #downloadResults} - Retrieve the processed output + *
+ * + * @author Sandeep Belgavi + */ +public final class SarvamVisionService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamVisionService.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamVisionService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + /** Result of a job creation request. */ + public record JobInfo(String jobId, String uploadUrl) {} + + /** Current status of a document intelligence job. */ + public record JobStatus(String jobId, String state, Optional downloadUrl) {} + + /** + * Creates a new document intelligence job. + * + * @param languageCode BCP-47 code (e.g., "hi-IN", "en-IN") + * @param outputFormat "html" or "md" + * @return job info with ID and upload URL + */ + public JobInfo createJob(String languageCode, String outputFormat) { + Objects.requireNonNull(languageCode); + Objects.requireNonNull(outputFormat); + + try { + String body = + objectMapper.writeValueAsString( + new java.util.HashMap() { + { + put("language", languageCode); + put("output_format", outputFormat); + } + }); + + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/create") + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug("Creating vision job: lang={}, format={}", languageCode, outputFormat); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Create job"); + JsonNode root = objectMapper.readTree(response.body().string()); + String jobId = root.path("job_id").asText(); + String uploadUrl = 
root.path("upload_url").asText(); + return new JobInfo(jobId, uploadUrl); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to create vision job", e); + } + } + + /** + * Uploads a document to the presigned upload URL. + * + * @param uploadUrl the presigned URL from {@link #createJob} + * @param filePath path to the document file + */ + public void uploadDocument(String uploadUrl, Path filePath) { + Objects.requireNonNull(uploadUrl); + Objects.requireNonNull(filePath); + + try { + byte[] fileBytes = Files.readAllBytes(filePath); + String contentType = Files.probeContentType(filePath); + if (contentType == null) { + contentType = "application/octet-stream"; + } + + Request request = + new Request.Builder() + .url(uploadUrl) + .put(RequestBody.create(fileBytes, MediaType.parse(contentType))) + .build(); + + logger.debug("Uploading document {} ({} bytes)", filePath, fileBytes.length); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new SarvamAiException( + "Document upload failed: " + response.code(), response.code(), null, null); + } + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to upload document", e); + } + } + + /** Starts processing a previously created and uploaded job. 
*/ + public void startJob(String jobId) { + Objects.requireNonNull(jobId); + + try { + String body = objectMapper.writeValueAsString(java.util.Map.of("job_id", jobId)); + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/start") + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug("Starting vision job {}", jobId); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Start job"); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to start vision job", e); + } + } + + /** Gets the current status of a document processing job. */ + public JobStatus getJobStatus(String jobId) { + Objects.requireNonNull(jobId); + + try { + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/status?job_id=" + jobId) + .addHeader("api-subscription-key", config.apiKey()) + .get() + .build(); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Get job status"); + JsonNode root = objectMapper.readTree(response.body().string()); + String state = root.path("job_state").asText("unknown"); + String downloadUrl = root.path("download_url").asText(null); + return new JobStatus(jobId, state, Optional.ofNullable(downloadUrl)); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to get job status", e); + } + } + + /** + * Downloads the processed results. 
+ * + * @param downloadUrl the URL from {@link JobStatus#downloadUrl()} + * @return the result bytes (typically a ZIP file containing HTML/Markdown) + */ + public byte[] downloadResults(String downloadUrl) { + Objects.requireNonNull(downloadUrl); + + try { + Request request = new Request.Builder().url(downloadUrl).get().build(); + + logger.debug("Downloading vision results from {}", downloadUrl); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new SarvamAiException( + "Download failed: " + response.code(), response.code(), null, null); + } + return response.body().bytes(); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to download results", e); + } + } + + /** + * Convenience method: runs the full pipeline (create -> upload -> start -> poll -> download) + * asynchronously. + */ + public Single processDocument(Path filePath, String languageCode, String outputFormat) { + return Single.create( + emitter -> { + try { + JobInfo job = createJob(languageCode, outputFormat); + uploadDocument(job.uploadUrl(), filePath); + startJob(job.jobId()); + + // Poll with backoff + int maxPolls = 60; + long pollIntervalMs = 2000; + for (int i = 0; i < maxPolls; i++) { + Thread.sleep(pollIntervalMs); + JobStatus status = getJobStatus(job.jobId()); + + if ("completed".equalsIgnoreCase(status.state())) { + if (status.downloadUrl().isPresent()) { + byte[] result = downloadResults(status.downloadUrl().get()); + emitter.onSuccess(result); + return; + } + emitter.onError( + new SarvamAiException("Job completed but no download URL provided")); + return; + } else if ("failed".equalsIgnoreCase(status.state())) { + emitter.onError(new SarvamAiException("Vision job failed: " + job.jobId())); + return; + } + + // Adaptive backoff + pollIntervalMs = Math.min(pollIntervalMs * 2, 10_000); + } + emitter.onError(new SarvamAiException("Vision job timed out: " + job.jobId())); 
+ } catch (Exception e) { + emitter.onError(e); + } + }) + .subscribeOn(Schedulers.io()); + } + + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } + + private void ensureSuccess(Response response, String operation) throws IOException { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : ""; + throw new SarvamAiException( + operation + " failed: " + response.code() + " " + errorBody, response.code(), null, null); + } + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java new file mode 100644 index 000000000..b1a5243a0 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +/** @author Sandeep Belgavi */ +class SarvamAiConfigTest { + + @Test + void builder_withApiKey_succeeds() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("test-key").build(); + assertThat(config.apiKey()).isEqualTo("test-key"); + } + + @Test + void builder_withoutApiKey_throwsIfEnvNotSet() { + if (System.getenv("SARVAM_API_KEY") != null) { + return; + } + assertThrows(IllegalArgumentException.class, () -> SarvamAiConfig.builder().build()); + } + + @Test + void builder_setsDefaultEndpoints() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.chatEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_CHAT_ENDPOINT); + assertThat(config.sttEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_STT_ENDPOINT); + assertThat(config.ttsEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_TTS_ENDPOINT); + assertThat(config.visionEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_VISION_ENDPOINT); + } + + @Test + void builder_customEndpoints() { + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("key") + .chatEndpoint("http://custom/chat") + .sttEndpoint("http://custom/stt") + .ttsEndpoint("http://custom/tts") + .build(); + + assertThat(config.chatEndpoint()).isEqualTo("http://custom/chat"); + assertThat(config.sttEndpoint()).isEqualTo("http://custom/stt"); + assertThat(config.ttsEndpoint()).isEqualTo("http://custom/tts"); + } + + @Test + void builder_temperatureValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").temperature(3.0).build()); + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").temperature(-1.0).build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").temperature(0.7).build(); + 
assertThat(config.temperature().getAsDouble()).isWithin(0.001).of(0.7); + } + + @Test + void builder_reasoningEffortValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").reasoningEffort("invalid").build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").reasoningEffort("high").build(); + assertThat(config.reasoningEffort()).hasValue("high"); + } + + @Test + void builder_sttModeValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").sttMode("invalid").build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").sttMode("translate").build(); + assertThat(config.sttMode()).hasValue("translate"); + } + + @Test + void builder_ttsPaceValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").ttsPace(0.1).build()); + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").ttsPace(3.0).build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").ttsPace(1.5).build(); + assertThat(config.ttsPace().getAsDouble()).isWithin(0.001).of(1.5); + } + + @Test + void builder_maxRetriesDefault() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.maxRetries()).isEqualTo(SarvamAiConfig.DEFAULT_MAX_RETRIES); + } + + @Test + void builder_wikiGrounding() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").wikiGrounding(true).build(); + assertThat(config.wikiGrounding()).hasValue(true); + } + + @Test + void builder_chatParametersOptionalByDefault() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.temperature().isEmpty()).isTrue(); + assertThat(config.topP().isEmpty()).isTrue(); + assertThat(config.maxTokens().isEmpty()).isTrue(); + assertThat(config.reasoningEffort().isEmpty()).isTrue(); + 
assertThat(config.wikiGrounding().isEmpty()).isTrue(); + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java new file mode 100644 index 000000000..9fb79c8f6 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java @@ -0,0 +1,213 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +
+/**
+ * Tests for {@link SarvamAi} chat completions, run against a local {@link MockWebServer} rather
+ * than the real Sarvam endpoint.
+ *
+ * @author Sandeep Belgavi
+ */
+class SarvamAiTest {
+
+  private MockWebServer server;
+  private SarvamAi sarvamAi;
+
+  /** Starts a fresh mock server and builds a {@link SarvamAi} whose chat endpoint targets it. */
+  @BeforeEach
+  void setUp() throws IOException {
+    server = new MockWebServer();
+    server.start();
+
+    SarvamAiConfig config =
+        SarvamAiConfig.builder()
+            .apiKey("test-api-key")
+            .chatEndpoint(server.url("/v1/chat/completions").toString())
+            .build();
+
+    sarvamAi =
+        SarvamAi.builder()
+            .modelName("sarvam-m")
+            .config(config)
+            .httpClient(new OkHttpClient())
+            .build();
+  }
+
+  @AfterEach
+  void tearDown() throws IOException {
+    server.shutdown();
+  }
+
+  @Test
+  void generateContent_nonStreaming_returnsContent() {
+    String jsonResponse =
+        "{\"id\":\"chatcmpl-abc\",\"object\":\"chat.completion\",\"created\":1699000000,"
+            + "\"model\":\"sarvam-m\",\"choices\":[{\"index\":0,"
+            + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello world\"},"
+            + "\"finish_reason\":\"stop\"}],"
+            + "\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}";
+    server.enqueue(new MockResponse().setBody(jsonResponse));
+
+    LlmRequest request = buildUserRequest("Hi");
+    // The type argument is required: with a raw TestSubscriber, values() yields Object elements.
+    TestSubscriber<LlmResponse> subscriber = sarvamAi.generateContent(request, false).test();
+
+    subscriber.awaitDone(5, TimeUnit.SECONDS);
+    subscriber.assertNoErrors();
+    subscriber.assertValueCount(1);
+
+    assertThat(firstPartText(subscriber.values().get(0))).isEqualTo("Hello world");
+  }
+
+  @Test
+  void generateContent_streaming_returnsChunks() {
+    String chunk1 = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n";
+    String chunk2 = "data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n";
+    String done = "data: [DONE]\n\n";
+
+    server.enqueue(new MockResponse().setBody(chunk1 + chunk2 + done));
+
+    LlmRequest request = buildUserRequest("Hi");
+    TestSubscriber<LlmResponse> subscriber = sarvamAi.generateContent(request, true).test();
+
+    subscriber.awaitDone(5, TimeUnit.SECONDS);
+    subscriber.assertNoErrors();
+    subscriber.assertValueCount(2);
+
+    assertThat(firstPartText(subscriber.values().get(0))).isEqualTo("Hello");
+    assertThat(firstPartText(subscriber.values().get(1))).isEqualTo(" world");
+  }
+
+  @Test
+  void generateContent_streamingChunksAreMarkedPartial() {
+    server.enqueue(
+        new MockResponse()
+            .setBody(
+                "data: {\"choices\":[{\"delta\":{\"content\":\"test\"}}]}\n\ndata: [DONE]\n\n"));
+
+    LlmRequest request = buildUserRequest("Hi");
+    TestSubscriber<LlmResponse> subscriber = sarvamAi.generateContent(request, true).test();
+
+    subscriber.awaitDone(5, TimeUnit.SECONDS);
+    subscriber.assertNoErrors();
+    LlmResponse response = subscriber.values().get(0);
+    assertThat(response.partial().orElse(false)).isTrue();
+  }
+
+  @Test
+  void generateContent_serverError_propagatesException() {
+    server.enqueue(
+        new MockResponse()
+            .setResponseCode(500)
+            .setBody(
+                "{\"error\":{\"message\":\"Internal error\",\"code\":\"internal_server_error\"}}"));
+
+    LlmRequest request = buildUserRequest("Hi");
+    TestSubscriber<LlmResponse> subscriber = sarvamAi.generateContent(request, false).test();
+
+    subscriber.awaitDone(5, TimeUnit.SECONDS);
+    subscriber.assertError(SarvamAiException.class);
+  }
+
+  @Test
+  void generateContent_usesCorrectAuthHeader() throws InterruptedException {
+    server.enqueue(
+        new MockResponse()
+            .setBody("{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"ok\"}}]}"));
+
+    sarvamAi.generateContent(buildUserRequest("Hi"), false).blockingSubscribe();
+
+    RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS);
+    assertThat(recorded).isNotNull();
+    assertThat(recorded.getHeader("api-subscription-key")).isEqualTo("test-api-key");
+  }
+
+  @Test
+  void generateContent_setsStreamFlagInBody() throws InterruptedException {
+    String chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n";
+    server.enqueue(new MockResponse().setBody(chunk));
+
+    sarvamAi.generateContent(buildUserRequest("Hello"), true).blockingSubscribe();
+
+    RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS);
+    assertThat(recorded).isNotNull();
+    String body = recorded.getBody().readUtf8();
+    assertThat(body).contains("\"stream\":true");
+  }
+
+  @Test
+  void generateContent_mapsModelRoleToAssistant() throws InterruptedException {
+    server.enqueue(
+        new MockResponse()
+            .setBody("{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"ok\"}}]}"));
+
+    LlmRequest request =
+        LlmRequest.builder()
+            .contents(
+                java.util.List.of(
+                    Content.builder().role("user").parts(Part.fromText("Hi")).build(),
+                    Content.builder().role("model").parts(Part.fromText("Hello")).build(),
+                    Content.builder().role("user").parts(Part.fromText("How?")).build()))
+            .build();
+
+    sarvamAi.generateContent(request, false).blockingSubscribe();
+
+    RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS);
+    // takeRequest returns null on timeout; fail with an assertion instead of an NPE below.
+    assertThat(recorded).isNotNull();
+    String body = recorded.getBody().readUtf8();
+    assertThat(body).contains("\"role\":\"assistant\"");
+    assertThat(body).doesNotContain("\"role\":\"model\"");
+  }
+
+  @Test
+  void builder_requiresModelName() {
+    assertThrows(
+        NullPointerException.class,
+        () -> SarvamAi.builder().config(SarvamAiConfig.builder().apiKey("key").build()).build());
+  }
+
+  @Test
+  void builder_requiresConfig() {
+    assertThrows(
+        NullPointerException.class, () -> SarvamAi.builder().modelName("sarvam-m").build());
+  }
+
+  /** Builds a request holding a single {@code user} message with the given text. */
+  private LlmRequest buildUserRequest(String text) {
+    return LlmRequest.builder()
+        .contents(
+            Collections.singletonList(
+                Content.builder().role("user").parts(Part.fromText(text)).build()))
+        .build();
+  }
+
+  /** Returns the text of the first part of {@code response}; throws if content or text is absent. */
+  private static String firstPartText(LlmResponse response) {
+    return response.content().flatMap(Content::parts).get().get(0).text().get();
+  }
+}
diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java new file mode 100644 index 000000000..f62907cde --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.jupiter.api.Test; +
+/**
+ * Tests for the delays returned by {@code SarvamRetryInterceptor.calculateDelay}.
+ *
+ * @author Sandeep Belgavi
+ */
+class SarvamRetryInterceptorTest {
+
+  /** Attempts 0, 1 and 2 must land in windows whose bounds double with each attempt. */
+  @Test
+  void calculateDelay_exponentiallyIncreases() {
+    int[] lowerBounds = {500, 1000, 2000};
+    int[] upperBounds = {700, 1400, 2800};
+
+    // Compute every delay up front, then check each against its window.
+    long[] delays = new long[lowerBounds.length];
+    for (int attempt = 0; attempt < delays.length; attempt++) {
+      delays[attempt] = SarvamRetryInterceptor.calculateDelay(attempt);
+    }
+
+    for (int attempt = 0; attempt < delays.length; attempt++) {
+      assertThat(delays[attempt]).isAtLeast(lowerBounds[attempt]);
+      assertThat(delays[attempt]).isAtMost(upperBounds[attempt]);
+    }
+  }
+
+  /** A high attempt number must not produce a delay above 36,000. */
+  @Test
+  void calculateDelay_respectsMaxCap() {
+    assertThat(SarvamRetryInterceptor.calculateDelay(10)).isAtMost(36_000);
+  }
+}
diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java new file mode 100644 index 000000000..aa39eb743 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.google.adk.models.sarvamai.chat; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.List; +import org.junit.jupiter.api.Test; +
+/**
+ * Tests for {@link ChatRequest#fromLlmRequest} and the JSON form of the resulting request.
+ *
+ * @author Sandeep Belgavi
+ */
+class ChatRequestTest {
+
+  private final ObjectMapper objectMapper = new ObjectMapper();
+
+  @Test
+  void fromLlmRequest_mapsUserAndAssistantMessages() throws Exception {
+    LlmRequest conversation =
+        LlmRequest.builder()
+            .contents(
+                List.of(
+                    textContent("user", "Hello"),
+                    textContent("model", "Hi there"),
+                    textContent("user", "How?")))
+            .build();
+    SarvamAiConfig sarvamConfig = SarvamAiConfig.builder().apiKey("key").temperature(0.5).build();
+
+    ChatRequest chatRequest =
+        ChatRequest.fromLlmRequest("sarvam-m", conversation, sarvamConfig, false);
+
+    assertThat(chatRequest.getModel()).isEqualTo("sarvam-m");
+    assertThat(chatRequest.getMessages()).hasSize(3);
+    assertThat(chatRequest.getMessages().get(0).getRole()).isEqualTo("user");
+    assertThat(chatRequest.getMessages().get(1).getRole()).isEqualTo("assistant");
+    assertThat(chatRequest.getMessages().get(2).getRole()).isEqualTo("user");
+    assertThat(chatRequest.getTemperature()).isWithin(0.001).of(0.5);
+    assertThat(chatRequest.getStream()).isNull();
+  }
+
+  @Test
+  void fromLlmRequest_includesSystemInstructions() throws Exception {
+    Content systemInstruction =
+        Content.builder().parts(Part.fromText("You are a helpful assistant")).build();
+    LlmRequest withSystemPrompt =
+        LlmRequest.builder()
+            .contents(List.of(textContent("user", "Hello")))
+            .config(GenerateContentConfig.builder().systemInstruction(systemInstruction).build())
+            .build();
+    SarvamAiConfig sarvamConfig = SarvamAiConfig.builder().apiKey("key").build();
+
+    ChatRequest chatRequest =
+        ChatRequest.fromLlmRequest("sarvam-m", withSystemPrompt, sarvamConfig, true);
+
+    assertThat(chatRequest.getMessages().get(0).getRole()).isEqualTo("system");
+    assertThat(chatRequest.getMessages().get(0).getContent())
+        .isEqualTo("You are a helpful assistant");
+    assertThat(chatRequest.getStream()).isTrue();
+  }
+
+  @Test
+  void fromLlmRequest_appliesConfigParameters() throws Exception {
+    SarvamAiConfig sarvamConfig =
+        SarvamAiConfig.builder()
+            .apiKey("key")
+            .temperature(0.7)
+            .topP(0.9)
+            .maxTokens(100)
+            .reasoningEffort("high")
+            .wikiGrounding(true)
+            .build();
+
+    ChatRequest chatRequest =
+        ChatRequest.fromLlmRequest("sarvam-m", singleUserRequest("test"), sarvamConfig, false);
+
+    assertThat(chatRequest.getTemperature()).isWithin(0.001).of(0.7);
+    assertThat(chatRequest.getTopP()).isWithin(0.001).of(0.9);
+    assertThat(chatRequest.getMaxTokens()).isEqualTo(100);
+    assertThat(chatRequest.getReasoningEffort()).isEqualTo("high");
+    assertThat(chatRequest.getWikiGrounding()).isTrue();
+  }
+
+  @Test
+  void serialization_excludesNullFields() throws Exception {
+    SarvamAiConfig sarvamConfig = SarvamAiConfig.builder().apiKey("key").build();
+
+    ChatRequest chatRequest =
+        ChatRequest.fromLlmRequest("sarvam-m", singleUserRequest("Hi"), sarvamConfig, false);
+    String serialized = objectMapper.writeValueAsString(chatRequest);
+
+    assertThat(serialized).doesNotContain("temperature");
+    assertThat(serialized).doesNotContain("stream");
+    assertThat(serialized).doesNotContain("wiki_grounding");
+    assertThat(serialized).contains("\"model\":\"sarvam-m\"");
+    assertThat(serialized).contains("\"messages\"");
+  }
+
+  /** Builds a {@link Content} with the given role and a single text part. */
+  private static Content textContent(String role, String text) {
+    return Content.builder().role(role).parts(Part.fromText(text)).build();
+  }
+
+  /** Builds a request with exactly one {@code user} message and no generation config. */
+  private static LlmRequest singleUserRequest(String text) {
+    return LlmRequest.builder().contents(List.of(textContent("user", text))).build();
+  }
+}
diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java
b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java new file mode 100644 index 000000000..8fca0ee6f --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.stt; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.transcription.TranscriptionConfig; +import com.google.adk.transcription.TranscriptionException; +import com.google.adk.transcription.TranscriptionResult; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +
+/**
+ * Tests for {@link SarvamSttService}, run against a local {@link MockWebServer}.
+ *
+ * @author Sandeep Belgavi
+ */
+class SarvamSttServiceTest {
+
+  private MockWebServer server;
+  private SarvamSttService sttService;
+
+  /** Starts the mock server and builds an STT service whose endpoint targets it. */
+  @BeforeEach
+  void setUp() throws IOException {
+    server = new MockWebServer();
+    server.start();
+
+    SarvamAiConfig.Builder configBuilder = SarvamAiConfig.builder();
+    configBuilder.apiKey("test-stt-key");
+    configBuilder.sttEndpoint(server.url("/speech-to-text").toString());
+    configBuilder.sttModel("saaras:v3");
+    configBuilder.sttMode("transcribe");
+    configBuilder.sttLanguageCode("hi-IN");
+
+    sttService = new SarvamSttService(configBuilder.build(), new OkHttpClient());
+  }
+
+  @AfterEach
+  void tearDown() throws IOException {
+    server.shutdown();
+  }
+
+  @Test
+  void transcribe_success() throws TranscriptionException, InterruptedException {
+    MockResponse okResponse =
+        new MockResponse()
+            .setBody(
+                "{\"request_id\":\"req-123\",\"transcript\":\"नमस्ते\",\"language_code\":\"hi-IN\"}");
+    server.enqueue(okResponse);
+
+    TranscriptionResult transcription =
+        sttService.transcribe(new byte[] {1, 2, 3}, hindiRequestConfig());
+
+    assertThat(transcription.getText()).isEqualTo("नमस्ते");
+    assertThat(transcription.getLanguage().orElse("")).isEqualTo("hi-IN");
+
+    RecordedRequest sent = server.takeRequest(5, TimeUnit.SECONDS);
+    assertThat(sent.getHeader("api-subscription-key")).isEqualTo("test-stt-key");
+  }
+
+  @Test
+  void transcribe_serverError_throwsException() {
+    server.enqueue(new MockResponse().setResponseCode(500).setBody("Server error"));
+    TranscriptionConfig requestConfig = hindiRequestConfig();
+
+    assertThrows(
+        TranscriptionException.class,
+        () -> sttService.transcribe(new byte[] {1, 2, 3}, requestConfig));
+  }
+
+  @Test
+  void isAvailable_returnsTrue() {
+    assertThat(sttService.isAvailable()).isTrue();
+  }
+
+  @Test
+  void getServiceType_returnsSarvam() {
+    assertThat(sttService.getServiceType().getValue()).isEqualTo("sarvam");
+  }
+
+  /** Builds the per-request config used by the transcribe tests: mock endpoint, {@code hi-IN}. */
+  private TranscriptionConfig hindiRequestConfig() {
+    return TranscriptionConfig.builder()
+        .endpoint(server.url("/speech-to-text").toString())
+        .language("hi-IN")
+        .build();
+  }
+}
diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java new file mode 100644 index 000000000..922cc8572 --- /dev/null +++
b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.tts; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import java.io.IOException; +import java.util.Base64; +import java.util.concurrent.TimeUnit; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +
+/**
+ * Tests for {@link SarvamTtsService}, run against a local {@link MockWebServer}.
+ *
+ * @author Sandeep Belgavi
+ */
+class SarvamTtsServiceTest {
+
+  private MockWebServer server;
+  private SarvamTtsService ttsService;
+
+  /** Starts the mock server and builds a TTS service whose endpoint targets it. */
+  @BeforeEach
+  void setUp() throws IOException {
+    server = new MockWebServer();
+    server.start();
+
+    SarvamAiConfig config =
+        SarvamAiConfig.builder()
+            .apiKey("test-tts-key")
+            .ttsEndpoint(server.url("/text-to-speech").toString())
+            .ttsModel("bulbul:v3")
+            .ttsSpeaker("shubh")
+            .build();
+
+    ttsService = new SarvamTtsService(config, new OkHttpClient());
+  }
+
+  @AfterEach
+  void tearDown() throws IOException {
+    server.shutdown();
+  }
+
+  @Test
+  void synthesize_success() throws InterruptedException {
+    // Explicit charset: the no-arg getBytes() depends on the platform default encoding.
+    byte[] expectedAudio = "fake-audio-data".getBytes(java.nio.charset.StandardCharsets.UTF_8);
+    String base64Audio = Base64.getEncoder().encodeToString(expectedAudio);
+    String responseBody =
+        String.format("{\"request_id\":\"req-456\",\"audios\":[\"%s\"]}", base64Audio);
+
+    server.enqueue(new MockResponse().setBody(responseBody));
+
+    byte[] audio = ttsService.synthesize("Hello world", "en-IN");
+
+    assertThat(audio).isEqualTo(expectedAudio);
+
+    RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS);
+    // takeRequest returns null on timeout; fail with an assertion instead of an NPE below.
+    assertThat(recorded).isNotNull();
+    assertThat(recorded.getHeader("api-subscription-key")).isEqualTo("test-tts-key");
+    String body = recorded.getBody().readUtf8();
+    assertThat(body).contains("\"model\":\"bulbul:v3\"");
+    assertThat(body).contains("\"speaker\":\"shubh\"");
+    assertThat(body).contains("\"target_language_code\":\"en-IN\"");
+  }
+
+  @Test
+  void synthesize_serverError_throwsException() {
+    server.enqueue(new MockResponse().setResponseCode(500).setBody("Server error"));
+
+    assertThrows(SarvamAiException.class, () -> ttsService.synthesize("Hello", "en-IN"));
+  }
+
+  @Test
+  void synthesize_emptyAudio_throwsException() {
+    server.enqueue(new MockResponse().setBody("{\"request_id\":\"req-789\",\"audios\":[]}"));
+
+    assertThrows(SarvamAiException.class, () -> ttsService.synthesize("Hello", "en-IN"));
+  }
+
+  @Test
+  void synthesize_nullText_throwsNpe() {
+    assertThrows(NullPointerException.class, () -> ttsService.synthesize(null, "en-IN"));
+  }
+
+  @Test
+  void isAvailable_returnsTrue() {
+    assertThat(ttsService.isAvailable()).isTrue();
+  }
+}
diff --git a/core/pom.xml b/core/pom.xml index 157ee2dc8..37db191d2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -167,6 +167,12 @@ wiremock-jre8 test + + com.squareup.okhttp3 + mockwebserver + 4.12.0 + test + io.opentelemetry opentelemetry-api @@ -321,4 +327,4 @@ - + \ No newline at end of file diff --git a/core/src/main/java/com/google/adk/models/GptOssLlm.java b/core/src/main/java/com/google/adk/models/GptOssLlm.java
index 331203ac6..895aba540 100644 --- a/core/src/main/java/com/google/adk/models/GptOssLlm.java +++ b/core/src/main/java/com/google/adk/models/GptOssLlm.java @@ -100,16 +100,16 @@ public GptOssLlm(String modelName) { * @param modelName The name of the GPT OSS model to use (e.g., "gpt-oss-4"). * @param vertexCredentials The Vertex AI credentials to access the model. */ -// public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { -// super(modelName); -// Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); -// Client.Builder apiClientBuilder = -// Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); -// vertexCredentials.project().ifPresent(apiClientBuilder::project); -// vertexCredentials.location().ifPresent(apiClientBuilder::location); -// vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); -// this.apiClient = apiClientBuilder.build(); -// } + // public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { + // super(modelName); + // Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); + // Client.Builder apiClientBuilder = + // Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); + // vertexCredentials.project().ifPresent(apiClientBuilder::project); + // vertexCredentials.location().ifPresent(apiClientBuilder::location); + // vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); + // this.apiClient = apiClientBuilder.build(); + // } /** * Returns a new Builder instance for constructing GptOssLlm objects. 
Note that when building a @@ -165,8 +165,7 @@ public GptOssLlm build() { if (apiClient != null) { return new GptOssLlm(modelName, apiClient); - } - else { + } else { return new GptOssLlm( modelName, Client.builder() @@ -354,4 +353,4 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } -} \ No newline at end of file +} diff --git a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java new file mode 100644 index 000000000..487dad652 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java @@ -0,0 +1,745 @@ +package com.google.adk.models; + +import static com.google.adk.models.RedbusADG.cleanForIdentifierPattern; +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Flowable; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import 
java.util.stream.Collectors; +import java.util.stream.Stream; +import org.json.JSONArray; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +
+/**
+ * BaseLlm implementation for Sarvam AI models.
+ *
+ * <p>Sarvam AI exposes an OpenAI-compatible chat completions API. The base URL is read from the
+ * {@code SARVAM_API_BASE} environment variable (default {@code https://api.sarvam.ai/v1}) and the
+ * API key from {@code SARVAM_API_KEY}.
+ *
+ * @author Sandeep Belgavi
+ */
+public class SarvamBaseLM extends BaseLlm {
+
+  public static final String SARVAM_API_BASE_ENV = "SARVAM_API_BASE";
+  public static final String SARVAM_API_KEY_ENV = "SARVAM_API_KEY";
+  private static final String DEFAULT_BASE_URL = "https://api.sarvam.ai/v1";
+  private static final int CONNECT_TIMEOUT_MS = 30_000;
+  private static final int READ_TIMEOUT_MS = 120_000;
+
+  private static final ObjectMapper OBJECT_MAPPER =
+      new ObjectMapper().registerModule(new Jdk8Module());
+
+  /** Explicit base URL, or {@code null} to resolve it from the environment at call time. */
+  private final String baseUrl;
+
+  private static final Logger logger = LoggerFactory.getLogger(SarvamBaseLM.class);
+
+  /** Synthetic user turn appended when the conversation does not already end with a user turn. */
+  private static final String CONTINUE_OUTPUT_MESSAGE =
+      "Continue output. DO NOT look at this line. ONLY look at the content before this line and"
+          + " system instruction.";
+
+  /**
+   * Creates a model whose base URL is resolved from {@code SARVAM_API_BASE} or the default.
+   *
+   * @param model the Sarvam model name sent in each request
+   */
+  public SarvamBaseLM(String model) {
+    super(model);
+    this.baseUrl = null;
+    warnIfApiKeyMissing();
+  }
+
+  /**
+   * Creates a model that talks to an explicit base URL.
+   *
+   * @param model the Sarvam model name sent in each request
+   * @param baseUrl base URL to which {@code /chat/completions} is appended; {@code null} falls
+   *     back to the environment variable or the default
+   */
+  public SarvamBaseLM(String model, String baseUrl) {
+    super(model);
+    this.baseUrl = baseUrl;
+    warnIfApiKeyMissing();
+  }
+
+  /** Logs a warning at construction time when {@code SARVAM_API_KEY} is unset or blank. */
+  private void warnIfApiKeyMissing() {
+    String apiKey = System.getenv(SARVAM_API_KEY_ENV);
+    if (apiKey == null || apiKey.isBlank()) {
+      logger.warn(
+          "SARVAM_API_KEY environment variable is not set. "
+              + "Sarvam API calls for model '{}' will fail with 401 Unauthorized.",
+          model());
+    }
+  }
+
+  /** Returns the constructor-supplied base URL, else {@code SARVAM_API_BASE}, else the default. */
+  private String resolveBaseUrl() {
+    if (baseUrl != null) {
+      return baseUrl;
+    }
+    String envUrl = System.getenv(SARVAM_API_BASE_ENV);
+    return envUrl != null ? envUrl : DEFAULT_BASE_URL;
+  }
+
+  /** Returns the value of {@code SARVAM_API_KEY}, which may be {@code null}. */
+  private String resolveApiKey() {
+    return System.getenv(SARVAM_API_KEY_ENV);
+  }
+
+  /**
+   * Sends the request to the chat completions endpoint.
+   *
+   * <p>In non-streaming mode exactly one response is emitted. If the API returns no choices
+   * (including the error and exception paths of {@code callChatCompletions}), the response holds
+   * an empty text part. When the reply contains a function call, only the first function-call
+   * part is returned.
+   *
+   * @param llmRequest the conversation, tools and generation config
+   * @param stream whether to use SSE streaming
+   * @return the model response or responses
+   */
+  @Override
+  public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
+    if (stream) {
+      return generateContentStream(llmRequest);
+    }
+
+    List<Content> contents = ensureLastContentIsUser(llmRequest.contents());
+
+    String systemText = extractSystemText(llmRequest);
+    JSONArray messages = buildMessages(systemText, contents);
+    JSONArray functions = buildTools(llmRequest);
+
+    // Tools are withheld on the turn that immediately follows a tool response.
+    JSONArray tools =
+        (!endsWithFunctionResponse(contents) && functions.length() > 0) ? functions : null;
+
+    float temperature =
+        llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f);
+    Optional<Integer> maxTokens =
+        llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens);
+
+    JSONObject response =
+        callChatCompletions(
+            this.model(), messages, tools, temperature, maxTokens.orElse(-1), false);
+
+    GenerateContentResponseUsageMetadata usageMetadata = extractUsageMetadata(response);
+
+    JSONArray choices = response.optJSONArray("choices");
+    if (choices == null || choices.length() == 0) {
+      logger.error("Sarvam API returned no choices: {}", response);
+      return Flowable.just(
+          LlmResponse.builder()
+              .content(Content.builder().role("model").parts(Part.fromText("")).build())
+              .build());
+    }
+
+    JSONObject message = choices.getJSONObject(0).getJSONObject("message");
+    List<Part> parts = openAiMessageToParts(message);
+
+    // A function call takes precedence over any text parts in the same message.
+    Part functionCallPart =
+        parts.stream().filter(p -> p.functionCall().isPresent()).findFirst().orElse(null);
+    List<Part> responseParts =
+        functionCallPart != null ? ImmutableList.of(functionCallPart) : ImmutableList.copyOf(parts);
+
+    LlmResponse.Builder responseBuilder =
+        LlmResponse.builder()
+            .content(Content.builder().role("model").parts(responseParts).build());
+
+    if (usageMetadata != null) {
+      responseBuilder.usageMetadata(usageMetadata);
+    }
+
+    return Flowable.just(responseBuilder.build());
+  }
+
+  /**
+   * Streams a chat completion over SSE.
+   *
+   * <p>Each text delta is emitted as a partial response. Tool-call name and argument fragments
+   * are accumulated, and a final response is emitted on {@code [DONE]} or end of stream.
+   */
+  private Flowable<LlmResponse> generateContentStream(LlmRequest llmRequest) {
+    List<Content> contents = ensureLastContentIsUser(llmRequest.contents());
+
+    String systemText = extractSystemText(llmRequest);
+    JSONArray messages = buildMessages(systemText, contents);
+    JSONArray functions = buildTools(llmRequest);
+
+    // Tools are withheld on the turn that immediately follows a tool response.
+    JSONArray tools =
+        (!endsWithFunctionResponse(contents) && functions.length() > 0) ? functions : null;
+
+    float temperature =
+        llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f);
+    Optional<Integer> maxTokens =
+        llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens);
+
+    final StringBuilder accumulatedText = new StringBuilder();
+    final StringBuilder functionCallName = new StringBuilder();
+    final StringBuilder functionCallArgs = new StringBuilder();
+    final AtomicBoolean inFunctionCall = new AtomicBoolean(false);
+    final AtomicBoolean streamCompleted = new AtomicBoolean(false);
+    final AtomicInteger inputTokens = new AtomicInteger(0);
+    final AtomicInteger outputTokens = new AtomicInteger(0);
+
+    return Flowable.generate(
+        () ->
+            callChatCompletionsStream(
+                this.model(), messages, tools, temperature, maxTokens.orElse(-1)),
+        (reader, emitter) -> {
+          try {
+            if (reader == null || streamCompleted.get()) {
+              emitter.onComplete();
+              return;
+            }
+
+            String line = reader.readLine();
+            if (line == null) {
+              // The stream ended without a [DONE] marker: flush what was accumulated.
+              emitFinalStreamResponse(
+                  emitter,
+                  accumulatedText,
+                  inFunctionCall,
+                  functionCallName,
+                  functionCallArgs,
+                  inputTokens.get(),
+                  outputTokens.get());
+              emitter.onComplete();
+              return;
+            }
+
+            // Only data fields matter here. SSE makes the space after the colon optional, so
+            // accept both "data: x" and "data:x".
+            if (!line.startsWith("data:")) {
+              return;
+            }
+            String jsonStr = line.substring("data:".length()).trim();
+
+            if (jsonStr.equals("[DONE]")) {
+              streamCompleted.set(true);
+              emitFinalStreamResponse(
+                  emitter,
+                  accumulatedText,
+                  inFunctionCall,
+                  functionCallName,
+                  functionCallArgs,
+                  inputTokens.get(),
+                  outputTokens.get());
+              emitter.onComplete();
+              return;
+            }
+
+            JSONObject chunk;
+            try {
+              chunk = new JSONObject(jsonStr);
+            } catch (Exception parseEx) {
+              logger.warn("Failed to parse Sarvam SSE chunk: {}", jsonStr, parseEx);
+              return;
+            }
+
+            if (chunk.has("usage") && !chunk.isNull("usage")) {
+              JSONObject usage = chunk.getJSONObject("usage");
+              inputTokens.set(usage.optInt("prompt_tokens", 0));
+              outputTokens.set(usage.optInt("completion_tokens", 0));
+            }
+
+            JSONArray choices = chunk.optJSONArray("choices");
+            if (choices == null || choices.length() == 0) {
+              return;
+            }
+
+            JSONObject delta = choices.getJSONObject(0).optJSONObject("delta");
+            if (delta == null) {
+              return;
+            }
+
+            if (delta.has("content") && !delta.isNull("content")) {
+              String text = delta.getString("content");
+              if (!text.isEmpty()) {
+                accumulatedText.append(text);
+                emitter.onNext(createTextResponse(text, true));
+              }
+            }
+
+            if (delta.has("tool_calls")) {
+              inFunctionCall.set(true);
+              JSONArray toolCalls = delta.getJSONArray("tool_calls");
+              if (toolCalls.length() > 0) {
+                // Only the first tool call of each delta is accumulated.
+                JSONObject function = toolCalls.getJSONObject(0).optJSONObject("function");
+                if (function != null) {
+                  if (function.has("name") && !function.isNull("name")) {
+                    functionCallName.append(function.getString("name"));
+                  }
+                  if (function.has("arguments") && !function.isNull("arguments")) {
+                    functionCallArgs.append(function.getString("arguments"));
+                  }
+                }
+              }
+            }
+          } catch (Exception e) {
+            logger.error("Error in Sarvam streaming", e);
+            emitter.onError(e);
+          }
+        },
+        reader -> {
+          try {
+            if (reader != null) {
+              reader.close();
+            }
+          } catch (IOException e) {
+            logger.error("Error closing stream reader", e);
+          }
+        });
+  }
+
+  /**
+   * Emits the closing response of a stream.
+   *
+   * <p>If a tool call was seen and has a name, a function-call response is emitted. Otherwise, if
+   * any text was accumulated, the full text is emitted with {@code partial(false)}. If neither
+   * applies, nothing is emitted. Usage metadata is attached when available.
+   */
+  private void emitFinalStreamResponse(
+      io.reactivex.rxjava3.core.Emitter<LlmResponse> emitter,
+      StringBuilder accumulatedText,
+      AtomicBoolean inFunctionCall,
+      StringBuilder functionCallName,
+      StringBuilder functionCallArgs,
+      int promptTokens,
+      int completionTokens) {
+
+    GenerateContentResponseUsageMetadata usageMetadata =
+        buildUsageMetadata(promptTokens, completionTokens);
+
+    if (inFunctionCall.get() && functionCallName.length() > 0) {
+      try {
+        String argsString = functionCallArgs.length() > 0 ? functionCallArgs.toString() : "{}";
+        Map<String, Object> args = new JSONObject(argsString).toMap();
+        FunctionCall fc =
+            FunctionCall.builder().name(functionCallName.toString()).args(args).build();
+        Part part = Part.builder().functionCall(fc).build();
+
+        LlmResponse.Builder builder =
+            LlmResponse.builder()
+                .content(Content.builder().role("model").parts(ImmutableList.of(part)).build());
+        if (usageMetadata != null) {
+          builder.usageMetadata(usageMetadata);
+        }
+        emitter.onNext(builder.build());
+      } catch (Exception funcEx) {
+        // Best effort: malformed accumulated arguments are logged and the call is dropped.
+        logger.error("Error creating function call response from stream", funcEx);
+      }
+    } else if (accumulatedText.length() > 0) {
+      LlmResponse.Builder builder =
+          LlmResponse.builder()
+              .content(
+                  Content.builder()
+                      .role("model")
+                      .parts(Part.fromText(accumulatedText.toString()))
+                      .build())
+              .partial(false);
+      if (usageMetadata != null) {
+        builder.usageMetadata(usageMetadata);
+      }
+      emitter.onNext(builder.build());
+    }
+  }
+
+  // ========== Request Building ==========
+
+  /**
+   * Returns {@code contents} unchanged when its last entry has role {@code user}; otherwise
+   * returns a copy with a synthetic user turn appended. The result is never empty.
+   */
+  private List<Content> ensureLastContentIsUser(List<Content> contents) {
+    if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) {
+      Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE));
+      return Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList());
+    }
+    return contents;
+  }
+
+  /**
+   * Reports whether the last part of the last content is a function response.
+   *
+   * @param contents a non-empty content list
+   * @return {@code false} when the last content has no parts
+   */
+  private static boolean endsWithFunctionResponse(List<Content> contents) {
+    List<Part> lastParts = Iterables.getLast(contents).parts().orElse(ImmutableList.of());
+    return !lastParts.isEmpty() && Iterables.getLast(lastParts).functionResponse().isPresent();
+  }
+
+  /** Joins the text parts of the system instruction with newlines; returns "" when absent. */
+  private String extractSystemText(LlmRequest llmRequest) {
+    return llmRequest
+        .config()
+        .flatMap(GenerateContentConfig::systemInstruction)
+        .flatMap(Content::parts)
+        .map(
+            parts ->
+                parts.stream()
+                    .filter(p -> p.text().isPresent())
+                    .map(p -> p.text().get())
+                    .collect(Collectors.joining("\n")))
+        .filter(text -> !text.isEmpty())
+        .orElse("");
+  }
+
+  /**
+   * Converts the system text and conversation into an OpenAI-style {@code messages} array.
+   *
+   * <p>The message kind is decided by the first part of each content: a function response
+   * becomes a {@code tool} message, a function call becomes an {@code assistant} message with
+   * {@code tool_calls}, and anything else becomes a plain text message with role {@code model}
+   * mapped to {@code assistant}.
+   */
+  private JSONArray buildMessages(String systemText, List<Content> contents) {
+    JSONArray messages = new JSONArray();
+
+    if (!systemText.isEmpty()) {
+      JSONObject systemMsg = new JSONObject();
+      systemMsg.put("role", "system");
+      systemMsg.put("content", systemText);
+      messages.put(systemMsg);
+    }
+
+    for (Content item : contents) {
+      String role = item.role().orElse("user");
+      List<Part> parts = item.parts().orElse(ImmutableList.of());
+      Part firstPart = parts.isEmpty() ? null : parts.get(0);
+
+      JSONObject msg = new JSONObject();
+      if (firstPart != null && firstPart.functionResponse().isPresent()) {
+        msg.put("role", "tool");
+        // Must equal the id written for the matching assistant tool call below ("call_" + name),
+        // otherwise the tool result cannot be correlated with the call that produced it.
+        msg.put(
+            "tool_call_id", "call_" + firstPart.functionResponse().get().name().orElse("unknown"));
+        msg.put(
+            "content",
+            new JSONObject(firstPart.functionResponse().get().response().orElse(Map.of()))
+                .toString());
+      } else if (firstPart != null && firstPart.functionCall().isPresent()) {
+        // Assistant message that previously requested a tool call
+        FunctionCall fc = firstPart.functionCall().get();
+        msg.put("role", "assistant");
+        msg.put("content", JSONObject.NULL);
+
+        JSONObject function = new JSONObject();
+        function.put("name", fc.name().orElse(""));
+        function.put("arguments", new JSONObject(fc.args().orElse(Map.of())).toString());
+
+        JSONObject toolCall = new JSONObject();
+        toolCall.put("id", "call_" + fc.name().orElse("unknown"));
+        toolCall.put("type", "function");
+        toolCall.put("function", function);
+
+        msg.put("tool_calls", new JSONArray().put(toolCall));
+      } else {
+        msg.put("role", role.equals("model") ? "assistant" : role);
+        msg.put("content", item.text());
+      }
+      messages.put(msg);
+    }
+    return messages;
+  }
+
+  /**
+   * Serializes the request's tool declarations into an OpenAI-style {@code tools} array. Tools
+   * without a declaration are skipped with a warning.
+   */
+  private JSONArray buildTools(LlmRequest llmRequest) {
+    JSONArray functions = new JSONArray();
+    llmRequest
+        .tools()
+        .forEach(
+            (name, baseTool) -> {
+              Optional<FunctionDeclaration> declOpt = baseTool.declaration();
+              if (declOpt.isEmpty()) {
+                logger.warn("Skipping tool '{}' with missing declaration.", baseTool.name());
+                return;
+              }
+
+              FunctionDeclaration decl = declOpt.get();
+              Map<String, Object> funcMap = new HashMap<>();
+              funcMap.put("name", cleanForIdentifierPattern(decl.name().get()));
+              funcMap.put("description", cleanForIdentifierPattern(decl.description().orElse("")));
+
+              Optional<Schema> paramsOpt = decl.parameters();
+              if (paramsOpt.isPresent()) {
+                Schema paramsSchema = paramsOpt.get();
+                Map<String, Object> paramsMap = new HashMap<>();
+                paramsMap.put("type", "object");
+
+                Optional<Map<String, Schema>> propsOpt = paramsSchema.properties();
+                if (propsOpt.isPresent()) {
+                  Map<String, Object> propsMap = new HashMap<>();
+                  propsOpt
+                      .get()
+                      .forEach(
+                          (key, schema) -> {
+                            Map<String, Object> schemaMap =
+                                OBJECT_MAPPER.convertValue(
+                                    schema, new TypeReference<Map<String, Object>>() {});
+                            normalizeTypeStrings(schemaMap);
+                            propsMap.put(key, schemaMap);
+                          });
+                  paramsMap.put("properties", propsMap);
+                }
+
+                paramsSchema
+                    .required()
+                    .ifPresent(requiredList -> paramsMap.put("required", requiredList));
+                funcMap.put("parameters", paramsMap);
+              }
+
+              JSONObject toolWrapper = new JSONObject();
+              toolWrapper.put("type", "function");
+              toolWrapper.put("function", new JSONObject(funcMap));
+              functions.put(toolWrapper);
+            });
+    return functions;
+  }
+
+  // ========== HTTP Transport ==========
+
+  /**
+   * Posts a chat completions request and returns the parsed JSON body.
+   *
+   * <p>This method never throws. An HTTP status of 400 or above yields an object with a single
+   * {@code error} field holding the raw body, or {@code "HTTP <code>"} when the server sent no
+   * body. Any exception yields an empty object. The connection is disconnected on every path.
+   *
+   * @param tools tool definitions, or {@code null} to send none
+   * @param maxTokens forwarded as {@code max_tokens} only when positive
+   */
+  private JSONObject callChatCompletions(
+      String model,
+      JSONArray messages,
+      JSONArray tools,
+      float temperature,
+      int maxTokens,
+      boolean stream) {
+    HttpURLConnection conn = null;
+    try {
+      String apiUrl = resolveBaseUrl() + "/chat/completions";
+      String apiKey = resolveApiKey();
+
+      JSONObject payload = new JSONObject();
+      payload.put("model", model);
+      payload.put("messages", messages);
+      payload.put("temperature", temperature);
+      payload.put("stream", stream);
+
+      if (maxTokens > 0) {
+        payload.put("max_tokens", maxTokens);
+      }
+
+      if (tools != null && tools.length() > 0) {
+        payload.put("tools", tools);
+        payload.put("tool_choice", "auto");
+      }
+
+      // Encode once so the logged size and the fixed-length header both use the real byte count.
+      byte[] body = payload.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
+      logger.debug("Sarvam request payload size: {} bytes", body.length);
+
+      conn = openConnection(apiUrl, apiKey);
+      conn.setFixedLengthStreamingMode(body.length);
+
+      try (OutputStream os = conn.getOutputStream()) {
+        os.write(body);
+        os.flush();
+      }
+
+      int responseCode = conn.getResponseCode();
+      logger.info("Sarvam response code: {} for model: {}", responseCode, model);
+
+      InputStream inputStream =
+          (responseCode < 400) ? conn.getInputStream() : conn.getErrorStream();
+      if (inputStream == null) {
+        // getErrorStream() returns null when the server sent no error body.
+        logger.error("Sarvam API error: status={} body={}", responseCode, "");
+        return new JSONObject().put("error", "HTTP " + responseCode);
+      }
+
+      try (BufferedReader reader =
+          new BufferedReader(
+              new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
+        StringBuilder sb = new StringBuilder();
+        String line;
+        while ((line = reader.readLine()) != null) {
+          sb.append(line);
+        }
+
+        if (responseCode >= 400) {
+          logger.error("Sarvam API error: status={} body={}", responseCode, sb);
+          return new JSONObject().put("error", sb.toString());
+        }
+
+        return new JSONObject(sb.toString());
+      }
+    } catch (Exception ex) {
+      logger.error("Error calling Sarvam chat completions API", ex);
+      return new JSONObject();
+    } finally {
+      if (conn != null) {
+        conn.disconnect();
+      }
+    }
+  }
+
+ private BufferedReader callChatCompletionsStream( + String model, JSONArray messages, JSONArray tools, float temperature, int maxTokens) { + try { + String apiUrl = resolveBaseUrl() + "/chat/completions"; + String apiKey = resolveApiKey(); + + JSONObject payload = new JSONObject(); + payload.put("model", model); + payload.put("messages", messages); + payload.put("temperature", temperature); + payload.put("stream", true); + + // Request token usage
in streaming responses + JSONObject streamOptions = new JSONObject(); + streamOptions.put("include_usage", true); + payload.put("stream_options", streamOptions); + + if (maxTokens > 0) { + payload.put("max_tokens", maxTokens); + } + + if (tools != null && tools.length() > 0) { + payload.put("tools", tools); + payload.put("tool_choice", "auto"); + } + + String jsonString = payload.toString(); + + HttpURLConnection conn = openConnection(apiUrl, apiKey); + conn.setRequestProperty("Accept", "text/event-stream"); + conn.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); + + try (OutputStream os = conn.getOutputStream(); + OutputStreamWriter writer = new OutputStreamWriter(os, "UTF-8")) { + writer.write(jsonString); + writer.flush(); + } + + int responseCode = conn.getResponseCode(); + logger.info("Sarvam streaming response code: {} for model: {}", responseCode, model); + + if (responseCode >= 200 && responseCode < 300) { + return new BufferedReader(new InputStreamReader(conn.getInputStream(), "UTF-8")); + } else { + try (InputStream errorStream = conn.getErrorStream(); + BufferedReader errorReader = + new BufferedReader(new InputStreamReader(errorStream, "UTF-8"))) { + StringBuilder errorResponse = new StringBuilder(); + String errorLine; + while ((errorLine = errorReader.readLine()) != null) { + errorResponse.append(errorLine); + } + logger.error("Sarvam streaming failed: status={} body={}", responseCode, errorResponse); + } + conn.disconnect(); + return null; + } + } catch (IOException ex) { + logger.error("Error in Sarvam streaming request", ex); + return null; + } + } + + private HttpURLConnection openConnection(String apiUrl, String apiKey) throws IOException { + URL url = new URL(apiUrl); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + conn.setConnectTimeout(CONNECT_TIMEOUT_MS); + conn.setReadTimeout(READ_TIMEOUT_MS); 
+ conn.setDoOutput(true); + if (apiKey != null && !apiKey.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + apiKey); + } + return conn; + } + + // ========== Response Parsing ========== + + private LlmResponse createTextResponse(String text, boolean partial) { + return LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(text)).build()) + .partial(partial) + .build(); + } + + private GenerateContentResponseUsageMetadata extractUsageMetadata(JSONObject response) { + if (response == null || !response.has("usage")) { + return null; + } + try { + JSONObject usage = response.getJSONObject("usage"); + int promptTokens = usage.optInt("prompt_tokens", 0); + int completionTokens = usage.optInt("completion_tokens", 0); + int totalTokens = usage.optInt("total_tokens", promptTokens + completionTokens); + + if (totalTokens > 0 || promptTokens > 0 || completionTokens > 0) { + logger.info( + "Sarvam token usage: prompt={}, completion={}, total={}", + promptTokens, + completionTokens, + totalTokens); + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(completionTokens) + .totalTokenCount(totalTokens) + .build(); + } + } catch (Exception e) { + logger.warn("Failed to parse token usage from Sarvam response", e); + } + return null; + } + + private GenerateContentResponseUsageMetadata buildUsageMetadata( + int promptTokens, int completionTokens) { + int totalTokens = promptTokens + completionTokens; + if (totalTokens > 0 || promptTokens > 0 || completionTokens > 0) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(completionTokens) + .totalTokenCount(totalTokens) + .build(); + } + return null; + } + + /** + * Converts an OpenAI-format message JSON to ADK Part(s). Handles both text content and tool_calls + * in a single message. 
+ */ + static List openAiMessageToParts(JSONObject message) { + List parts = new ArrayList<>(); + + if (message.has("tool_calls")) { + JSONArray toolCalls = message.optJSONArray("tool_calls"); + if (toolCalls != null && toolCalls.length() > 0) { + JSONObject toolCall = toolCalls.getJSONObject(0); + JSONObject function = toolCall.optJSONObject("function"); + if (function != null) { + String name = function.optString("name", null); + String argsStr = function.optString("arguments", "{}"); + if (name != null) { + Map args = new JSONObject(argsStr).toMap(); + FunctionCall fc = FunctionCall.builder().name(name).args(args).build(); + parts.add(Part.builder().functionCall(fc).build()); + return parts; + } + } + } + } + + if (message.has("content") && !message.isNull("content")) { + parts.add(Part.builder().text(message.getString("content")).build()); + } else { + parts.add(Part.builder().text("").build()); + } + + return parts; + } + + @SuppressWarnings("unchecked") + private void normalizeTypeStrings(Map valueDict) { + if (valueDict == null) { + return; + } + if (valueDict.containsKey("type") && valueDict.get("type") instanceof String) { + valueDict.put("type", ((String) valueDict.get("type")).toLowerCase()); + } + if (valueDict.containsKey("items") && valueDict.get("items") instanceof Map) { + Map itemsMap = (Map) valueDict.get("items"); + normalizeTypeStrings(itemsMap); + if (itemsMap.containsKey("properties") && itemsMap.get("properties") instanceof Map) { + Map properties = (Map) itemsMap.get("properties"); + for (Object value : properties.values()) { + if (value instanceof Map) { + normalizeTypeStrings((Map) value); + } + } + } + } + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + return new GenericLlmConnection(this, llmRequest); + } +} diff --git a/core/src/main/java/com/google/adk/transcription/ServiceType.java b/core/src/main/java/com/google/adk/transcription/ServiceType.java index 2d4ae233f..98203eee4 100644 --- 
a/core/src/main/java/com/google/adk/transcription/ServiceType.java +++ b/core/src/main/java/com/google/adk/transcription/ServiceType.java @@ -33,7 +33,10 @@ public enum ServiceType { AZURE("azure"), /** AWS Transcribe (future). */ - AWS_TRANSCRIBE("aws_transcribe"); + AWS_TRANSCRIBE("aws_transcribe"), + + /** Sarvam AI transcription. */ + SARVAM("sarvam"); private final String value; diff --git a/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java b/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java index 7fb85fb9c..0de92d50b 100644 --- a/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java +++ b/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java @@ -23,19 +23,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Loads transcription configuration from environment variables. Follows 12-Factor App principles. - * - *

Transcription is an optional feature. If ADK_TRANSCRIPTION_ENDPOINT is not set, this returns - * Optional.empty(), allowing the framework to work without transcription. - * - * @author Sandeep Belgavi - * @since 2026-01-24 - */ +/** Loads transcription configuration from environment variables or system properties. */ public class TranscriptionConfigLoader { private static final Logger logger = LoggerFactory.getLogger(TranscriptionConfigLoader.class); - // Environment variable names + // Variable names private static final String ENDPOINT_ENV = "ADK_TRANSCRIPTION_ENDPOINT"; private static final String API_KEY_ENV = "ADK_TRANSCRIPTION_API_KEY"; private static final String LANGUAGE_ENV = "ADK_TRANSCRIPTION_LANGUAGE"; @@ -44,16 +36,23 @@ public class TranscriptionConfigLoader { private static final String SERVICE_TYPE_ENV = "ADK_TRANSCRIPTION_SERVICE_TYPE"; private static final String CHUNK_SIZE_ENV = "ADK_TRANSCRIPTION_CHUNK_SIZE_MS"; - /** - * Loads configuration from environment variables. Returns Optional.empty() if transcription is - * not configured (optional feature). 
- * - * @return Optional containing TranscriptionConfig if configured - */ + private static String getValue(String key) { + String val = System.getProperty(key); + if (val == null || val.isEmpty()) { + val = System.getenv(key); + } + return val; + } + public static Optional loadFromEnvironment() { - String endpoint = System.getenv(ENDPOINT_ENV); + String endpoint = getValue(ENDPOINT_ENV); + + // For Sarvam, we can default the endpoint if service type is sarvam + String serviceType = getValue(SERVICE_TYPE_ENV); + if ("sarvam".equalsIgnoreCase(serviceType) && (endpoint == null || endpoint.isEmpty())) { + endpoint = "https://api.sarvam.ai/speech-to-text"; + } - // Transcription is optional - return empty if not configured if (endpoint == null || endpoint.isEmpty()) { logger.debug("Transcription not configured ({} not set)", ENDPOINT_ENV); return Optional.empty(); @@ -61,20 +60,23 @@ public static Optional loadFromEnvironment() { TranscriptionConfig.Builder builder = TranscriptionConfig.builder().endpoint(endpoint); - // Optional: API Key - String apiKey = System.getenv(API_KEY_ENV); + String apiKey = getValue(API_KEY_ENV); + if (apiKey == null || apiKey.isEmpty()) { + apiKey = getValue("SARVAM_API_KEY"); + } + if (apiKey != null && !apiKey.isEmpty()) { builder.apiKey(apiKey); } - // Optional: Language (default: auto) - String language = System.getenv(LANGUAGE_ENV); + String language = getValue(LANGUAGE_ENV); if (language != null && !language.isEmpty()) { builder.language(language); + } else if ("sarvam".equalsIgnoreCase(serviceType)) { + builder.language("hi-IN"); // Default for Sarvam POC } - // Optional: Timeout (default: 30 seconds) - String timeoutStr = System.getenv(TIMEOUT_ENV); + String timeoutStr = getValue(TIMEOUT_ENV); if (timeoutStr != null) { try { int timeoutSeconds = Integer.parseInt(timeoutStr); @@ -86,43 +88,12 @@ public static Optional loadFromEnvironment() { } } - // Optional: Max retries (default: 3) - String maxRetriesStr = 
System.getenv(MAX_RETRIES_ENV); - if (maxRetriesStr != null) { - try { - int maxRetries = Integer.parseInt(maxRetriesStr); - if (maxRetries >= 0) { - builder.maxRetries(maxRetries); - } - } catch (NumberFormatException e) { - logger.warn("Invalid max retries value: {}, using default", maxRetriesStr); - } - } - - // Optional: Chunk size (default: 500ms) - String chunkSizeStr = System.getenv(CHUNK_SIZE_ENV); - if (chunkSizeStr != null) { - try { - int chunkSizeMs = Integer.parseInt(chunkSizeStr); - if (chunkSizeMs > 0) { - builder.chunkSizeMs(chunkSizeMs); - } - } catch (NumberFormatException e) { - logger.warn("Invalid chunk size value: {}, using default", chunkSizeStr); - } - } - - // Audio format (default: PCM 16kHz Mono) builder.audioFormat(AudioFormat.PCM_16KHZ_MONO); - - // Enable partial results for real-time streaming builder.enablePartialResults(true); TranscriptionConfig config = builder.build(); logger.info( - "Loaded transcription config: endpoint={}, service={}", - config.getEndpoint(), - System.getenv(SERVICE_TYPE_ENV)); + "Loaded transcription config: endpoint={}, service={}", config.getEndpoint(), serviceType); return Optional.of(config); } diff --git a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java index c9c28d928..9260e7849 100644 --- a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java +++ b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java @@ -84,6 +84,11 @@ private static TranscriptionService createService(TranscriptionConfig config) { ServiceType serviceType = determineServiceType(config); switch (serviceType) { + case SARVAM: + throw new UnsupportedOperationException( + "Sarvam STT has moved to the contrib/sarvam-ai module. 
" + + "Use SarvamSttService from com.google.adk.models.sarvamai.stt instead."); + case WHISPER: return createWhisperService(config); diff --git a/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java new file mode 100644 index 000000000..ef3d6edb5 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java @@ -0,0 +1,177 @@ +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; +import java.util.List; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +/** + * @author Sandeep Belgavi + */ +public final class SarvamBaseLMTest { + + // ========== openAiMessageToParts tests ========== + + @Test + public void openAiMessageToParts_textContent_returnsTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "Hello world"); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Hello world"); + assertThat(parts.get(0).functionCall()).isEmpty(); + } + + @Test + public void openAiMessageToParts_nullContent_returnsEmptyTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", JSONObject.NULL); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue(""); + } + + @Test + public void openAiMessageToParts_missingContent_returnsEmptyTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue(""); + } + + 
@Test + public void openAiMessageToParts_toolCall_returnsFunctionCallPart() { + JSONObject function = new JSONObject(); + function.put("name", "getBusSearch"); + function.put("arguments", "{\"source\":\"Bangalore\",\"dest\":\"Chennai\"}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_abc123"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", JSONObject.NULL); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + FunctionCall fc = parts.get(0).functionCall().get(); + assertThat(fc.name()).hasValue("getBusSearch"); + assertThat(fc.args()).isPresent(); + assertThat(fc.args().get()).containsEntry("source", "Bangalore"); + assertThat(fc.args().get()).containsEntry("dest", "Chennai"); + } + + @Test + public void openAiMessageToParts_toolCallWithEmptyArgs_returnsFunctionCallWithEmptyMap() { + JSONObject function = new JSONObject(); + function.put("name", "getOffers"); + function.put("arguments", "{}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_xyz"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + assertThat(parts.get(0).functionCall().get().name()).hasValue("getOffers"); + assertThat(parts.get(0).functionCall().get().args().get()).isEmpty(); + } + + @Test + public void openAiMessageToParts_toolCallTakesPriorityOverContent() 
{ + JSONObject function = new JSONObject(); + function.put("name", "search"); + function.put("arguments", "{}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_1"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "I'll search for you"); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + assertThat(parts.get(0).functionCall().get().name()).hasValue("search"); + } + + @Test + public void openAiMessageToParts_emptyToolCalls_fallsBackToContent() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "Here are the results"); + message.put("tool_calls", new JSONArray()); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Here are the results"); + } + + // ========== Constructor / config tests ========== + + @Test + public void constructor_setsModelName() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m"); + assertThat(llm.model()).isEqualTo("sarvam-m"); + } + + @Test + public void constructor_withBaseUrl_setsModelName() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m", "https://custom.api.com/v1"); + assertThat(llm.model()).isEqualTo("sarvam-m"); + } + + @Test + public void connect_returnsGenericLlmConnection() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m"); + LlmRequest request = LlmRequest.builder().build(); + + BaseLlmConnection connection = llm.connect(request); + + assertThat(connection).isInstanceOf(GenericLlmConnection.class); + } +} diff --git a/dev/src/main/resources/application.properties b/dev/src/main/resources/application.properties index 
0ff0eb627..a7a8dee80 100644 --- a/dev/src/main/resources/application.properties +++ b/dev/src/main/resources/application.properties @@ -1,11 +1,15 @@ -# Spring Boot Server Configuration -# Author: Sandeep Belgavi -# Date: January 24, 2026 +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -# Spring Boot server port (for Spring SSE endpoint) -server.port=9086 - -# HttpServer SSE Configuration (default SSE endpoint) -adk.httpserver.sse.enabled=true -adk.httpserver.sse.port=9085 -adk.httpserver.sse.host=0.0.0.0 +adk.httpserver.sse.port=9999 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 6a1aa5af5..46971625c 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ maven_plugin contrib/langchain4j contrib/spring-ai + contrib/sarvam-ai contrib/samples contrib/firestore-session-service tutorials/city-time-weather