Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 179 additions & 23 deletions core/src/main/java/com/google/adk/models/BedrockBaseLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -55,10 +56,51 @@
*/
public class BedrockBaseLM extends BaseLlm {

// Use a constant for the environment variable name
public static final String BEDROCK_ENV_VAR = "BEDROCK_URL";
public String D_URL = null;

/**
* Bearer token env vars (checked in order). BEDROCK_API_KEY (FMIS) preferred - works for both
* Converse + foundation-models in ap-south-1/ap-southeast-1. AWS_BEARER_TOKEN_BEDROCK is legacy.
*/
private static final String[] BEARER_TOKEN_ENV_VARS = {
"BEDROCK_API_KEY", // FMIS key - works for Converse + foundation-models
"AWS_BEARER_TOKEN_BEDROCK", // Legacy runtime token
"BEDROCK_BEARER_TOKEN",
"BEDROCK_TOKEN"
};

/** Returns Bearer token from env. Prefers BEDROCK_API_KEY (FMIS) when available. */
public static String getBearerToken() {
for (String name : BEARER_TOKEN_ENV_VARS) {
String v = System.getenv(name);
if (v != null && !v.isBlank()) return v;
}
return null;
}

/** Base URL for Converse API. Uses BEDROCK_URL, or BEDROCK_REGION, or default ap-south-1. */
private static String getBedrockBaseUrl(String overrideUrl) {
if (overrideUrl != null && !overrideUrl.isBlank()) return overrideUrl;
String url = System.getenv(BEDROCK_ENV_VAR);
if (url != null && !url.isBlank()) return url;
String region = System.getenv("BEDROCK_REGION");
if (region == null || region.isBlank()) region = "ap-south-1";
return "https://bedrock-runtime." + region + ".amazonaws.com";
}

/**
* Returns the full Converse API URL. Handles BEDROCK_URL with or without trailing /model to avoid
* double /model/ in path.
*/
private static String buildConverseUrl(String baseUrl, String model) {
String base = baseUrl == null ? "" : baseUrl.trim().replaceAll("/+$", "");
if (base.endsWith("/model")) {
return base + "/" + model + "/converse";
}
return base + "/model/" + model + "/converse";
}

// Corrected the logger name to use OllamaBaseLM.class
private static final Logger logger = LoggerFactory.getLogger(BedrockBaseLM.class);

Expand Down Expand Up @@ -304,7 +346,30 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
logger.debug("Usage metadata parsing failed (non-critical)", e);
}

JSONObject responseQuantum = agentresponse.getJSONObject("output").getJSONObject("message");
// Bedrock Converse API: output.message, or message/Message at top level.
// Message (capital M) can be a String in error responses (e.g. "Authentication failed").
JSONObject responseQuantum = extractMessageObject(agentresponse);
if (responseQuantum == null) {
String detail = "Response keys: " + agentresponse.keySet();
if (agentresponse.has("Output")) {
try {
JSONObject out = agentresponse.getJSONObject("Output");
detail += ", Output keys: " + out.keySet();
} catch (Exception e) {
detail += ", Output: (not JSONObject)";
}
} else if (agentresponse.has("output")) {
try {
JSONObject out = agentresponse.getJSONObject("output");
detail += ", output keys: " + out.keySet();
} catch (Exception e) {
detail += ", output: (not JSONObject)";
}
}
throw new IllegalStateException(
"Unexpected Bedrock response: missing output/Output.message, message, or Message. "
+ detail);
}

// Check if tool call is required
// Tools call
Expand Down Expand Up @@ -335,6 +400,59 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
return Flowable.just(responseBuilder.build());
}

/** Gets a value from JSONObject using case-insensitive key match. */
private static Object getKeyIgnoreCase(JSONObject obj, String... keys) {
Iterator<String> it = obj.keys();
while (it.hasNext()) {
String k = it.next();
for (String target : keys) {
if (k.equalsIgnoreCase(target)) return obj.get(k);
}
}
return null;
}

/**
* Extracts the message content object from a Bedrock response. Handles both success (message
* object) and error (Message as String) responses. Supports output/Output (AWS PascalCase) and
* case-insensitive keys for FMIS/global endpoints.
*/
private static JSONObject extractMessageObject(JSONObject response) {
Object msg = null;
String errMsg = null;
JSONObject outputObj =
response.has("output")
? response.getJSONObject("output")
: response.has("Output") ? response.optJSONObject("Output") : null;
if (outputObj != null) {
msg = getKeyIgnoreCase(outputObj, "message", "Message");
if (msg == null
&& outputObj.has("choices")
&& outputObj.getJSONArray("choices").length() > 0) {
JSONObject first = outputObj.getJSONArray("choices").getJSONObject(0);
msg = getKeyIgnoreCase(first, "message", "Message");
}
if (msg == null && outputObj.has("content")) {
msg = outputObj;
}
}
if (msg == null) {
msg = getKeyIgnoreCase(response, "message", "Message");
}
if (msg instanceof String) {
errMsg = (String) msg;
} else if (msg instanceof JSONObject) {
return (JSONObject) msg;
}
if (errMsg != null) {
throw new IllegalStateException("Bedrock API error: " + errMsg);
}
if (outputObj != null) {
logger.debug("Bedrock Output keys (extraction failed): {}", outputObj.keySet());
}
return null;
}

public Flowable<LlmResponse> generateContentStream(LlmRequest llmRequest) {
List<Content> contents = llmRequest.contents();
// Last content must be from the user, otherwise the model won't respond.
Expand Down Expand Up @@ -578,13 +696,31 @@ private Flowable<LlmResponse> createRobustStreamingResponse(
}

JSONObject message = null;
if (responseJson.has("output")) {
JSONObject output = responseJson.getJSONObject("output");
if (output.has("message")) {
message = output.getJSONObject("message");
Object msgVal = null;
JSONObject outputObj =
responseJson.has("output")
? responseJson.optJSONObject("output")
: responseJson.has("Output") ? responseJson.optJSONObject("Output") : null;
if (outputObj != null) {
msgVal = getKeyIgnoreCase(outputObj, "message", "Message");
if (msgVal == null
&& outputObj.has("choices")
&& outputObj.getJSONArray("choices").length() > 0) {
JSONObject first = outputObj.getJSONArray("choices").getJSONObject(0);
msgVal = getKeyIgnoreCase(first, "message", "Message");
}
if (msgVal == null && outputObj.has("content")) {
msgVal = outputObj;
}
} else if (responseJson.has("message")) {
message = responseJson.getJSONObject("message");
}
if (msgVal == null) {
msgVal = getKeyIgnoreCase(responseJson, "message", "Message");
}
if (msgVal instanceof String) {
emitter.onError(new IllegalStateException("Bedrock API error: " + msgVal));
return reader;
} else if (msgVal instanceof JSONObject) {
message = (JSONObject) msgVal;
}

// Accumulate all text from this response chunk
Expand Down Expand Up @@ -757,9 +893,14 @@ private LlmResponse createTextResponse(String text, boolean partial) {

public BufferedReader callLLMChatStream(String model, JSONArray messages, JSONArray tools) {
try {
String apiUrl =
(D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR)) + "/" + model + "/converse";
String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK");
String bearerToken = getBearerToken();
if (bearerToken == null || bearerToken.isBlank()) {
throw new IllegalStateException(
"Bedrock Bearer token not found. Set one of: AWS_BEARER_TOKEN_BEDROCK, "
+ "BEDROCK_BEARER_TOKEN, BEDROCK_API_KEY, BEDROCK_TOKEN (e.g. in .bashrc)");
}
String baseUrl = getBedrockBaseUrl(D_URL);
String apiUrl = buildConverseUrl(baseUrl, model);
System.out.println("Using Bedrock URL: " + apiUrl);
JSONObject payload = new JSONObject();
// Model already encoded in path; omit 'model' field to avoid Unexpected field type errors
Expand All @@ -775,7 +916,7 @@ public BufferedReader callLLMChatStream(String model, JSONArray messages, JSONAr
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK);
connection.setRequestProperty("Authorization", "Bearer " + bearerToken);
connection.setDoOutput(true);
connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length);

Expand Down Expand Up @@ -885,11 +1026,17 @@ public static Part ollamaContentBlockToPart(JSONObject blockJson) {
*/
public JSONObject callLLMChat(String model, JSONArray messages, JSONArray tools) {
try {
String bearerToken = getBearerToken();
if (bearerToken == null || bearerToken.isBlank()) {
throw new IllegalStateException(
"Bedrock Bearer token not found. Set one of: AWS_BEARER_TOKEN_BEDROCK, "
+ "BEDROCK_BEARER_TOKEN, BEDROCK_API_KEY, BEDROCK_TOKEN (e.g. in .bashrc)");
}
String baseUrl = getBedrockBaseUrl(D_URL);
JSONObject responseJ = new JSONObject();
String apiUrl = D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR);
String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK");
String apiUrl = buildConverseUrl(baseUrl, model);
JSONObject payload = new JSONObject();
payload.put("model", model);
// Model already in path; omit from payload to avoid "Unexpected field type" errors
payload.put("stream", false);
payload.put("messages", messages);
if (tools != null) {
Expand All @@ -901,7 +1048,7 @@ public JSONObject callLLMChat(String model, JSONArray messages, JSONArray tools)
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK);
connection.setRequestProperty("Authorization", "Bearer " + bearerToken);
connection.setDoOutput(true);
connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length);

Expand Down Expand Up @@ -958,9 +1105,12 @@ public static JSONObject callLLMChat(
boolean stream, String prompt, String model, JSONArray messages, JSONArray tools) {
JSONObject responseJ = new JSONObject();
try {
String apiUrl = System.getenv(BEDROCK_ENV_VAR);
String AWS_BEARER_TOKEN_BEDROCK = System.getenv(BEDROCK_ENV_VAR);
apiUrl = apiUrl + "/api/chat";
String bearerToken = getBearerToken();
if (bearerToken == null || bearerToken.isBlank()) {
throw new IllegalStateException(
"Bedrock Bearer token not found. Set AWS_BEARER_TOKEN_BEDROCK, BEDROCK_BEARER_TOKEN, etc.");
}
String apiUrl = getBedrockBaseUrl(null) + "/api/chat";
JSONObject payload = new JSONObject();
payload.put("model", model);
payload.put("stream", false);
Expand All @@ -978,7 +1128,7 @@ public static JSONObject callLLMChat(
// System.out.print("HTTP Connection to Ollama API: " + apiUrl.toString());
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/json");
connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK);
connection.setRequestProperty("Authorization", "Bearer " + bearerToken);
connection.setDoOutput(true);
connection.setFixedLengthStreamingMode(jsonString.getBytes().length);
try (DataOutputStream outputStream = new DataOutputStream(connection.getOutputStream())) {
Expand Down Expand Up @@ -1070,8 +1220,14 @@ public Flowable<JSONObject> generateContent(
return Flowable.create(
emitter -> {
try {
String apiUrl = D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR);
String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK");
String bearerToken = getBearerToken();
if (bearerToken == null || bearerToken.isBlank()) {
emitter.onError(
new IllegalStateException(
"Bedrock Bearer token not found. Set AWS_BEARER_TOKEN_BEDROCK, BEDROCK_BEARER_TOKEN, etc."));
return;
}
String apiUrl = getBedrockBaseUrl(D_URL);
JSONObject payload = new JSONObject();
payload.put("messages", messages);
if (tools != null) {
Expand All @@ -1082,7 +1238,7 @@ public Flowable<JSONObject> generateContent(
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK);
connection.setRequestProperty("Authorization", "Bearer " + bearerToken);
connection.setDoOutput(true);
connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length);
try (OutputStream outputStream = connection.getOutputStream();
Expand Down
Loading