diff --git a/src/main/java/io/openliberty/tools/common/ai/ChatAgent.java b/src/main/java/io/openliberty/tools/common/ai/ChatAgent.java
index f7cd42d6..5723dfb0 100644
--- a/src/main/java/io/openliberty/tools/common/ai/ChatAgent.java
+++ b/src/main/java/io/openliberty/tools/common/ai/ChatAgent.java
@@ -33,6 +33,11 @@ public ChatAgent(int memoryId) throws Exception {
         getAssistant();
     }
 
+    public void clearAssistant() {
+        resetChat();
+        this.assistant = null;
+    }
+
     public Assistant getAssistant() throws Exception {
         if (assistant == null) {
             AiServices builder =
diff --git a/src/main/java/io/openliberty/tools/common/ai/util/ModelBuilder.java b/src/main/java/io/openliberty/tools/common/ai/util/ModelBuilder.java
index 884bcf26..4ff026dd 100644
--- a/src/main/java/io/openliberty/tools/common/ai/util/ModelBuilder.java
+++ b/src/main/java/io/openliberty/tools/common/ai/util/ModelBuilder.java
@@ -18,7 +18,6 @@
 import static java.time.Duration.ofSeconds;
 
 import java.util.List;
-import java.util.Scanner;
 
 import dev.langchain4j.model.chat.ChatModel;
 import dev.langchain4j.model.ollama.OllamaChatModel;
@@ -29,7 +28,7 @@ public class ModelBuilder {
 
     public static final String OLLAMA = "Ollama";
 
-    private static String OLLAMA_BASE_URL = System.getenv("OLLAMA_BASE_URL");
+    private static String OLLAMA_BASE_URL = null;
 
     private static String model;
     private static String provider;
@@ -40,16 +39,6 @@ public class ModelBuilder {
     private Integer MAX_MESSAGES;
     private Double TEMPERATURE;
 
-    private static Scanner scan = new Scanner(System.in);
-
-    private static void modelSelection() {
-        System.out.print("\nPress enter to use the default model " + Utils.bold(model) + " or type in a model name: ");
-        String modelName = scan.nextLine().trim();
-        if (!modelName.isBlank()) {
-            model = modelName;
-        }
-    }
-
     public static ChatModel chatModel() {
         return chatModel;
     }
@@ -60,43 +49,34 @@ public static void cleanInputProvider() {
         provider = null;
     }
 
-    public static boolean promptInputProvider() {
-        String provider = "";
-        while (!(("ollama".equals(provider) || "1".equals(provider)) || ("github".equals(provider)|| "2".equals(provider)) ||
-                ("mistral".equals(provider) || "3".equals(provider)) || ("gemini".equals(provider) || "4".equals(provider)))) {
-            System.out.print("\n\nSelect a provider (1:Ollama, 2:Github, 3:Mistral, or press Enter to quit): ");
-            provider = scan.nextLine().toLowerCase().trim();
-            if (provider.isEmpty()) {
-                System.out.println("Skipped to enable AI mode.");
-                return false;
-            }
-        }
+    public static boolean selectInputProvider() throws Exception {
 
-        Boolean validResponse = false;
-        while (!validResponse) {
-
-            String apiKeyOrUrl;
-            if ("ollama".equals(provider)|| "1".equals(provider)) {
-                System.out.print("\nEnter a valid Ollama URL: ");
-                apiKeyOrUrl = scan.nextLine().trim();
-            } else {
-                apiKeyOrUrl = Utils.getReader().readLine("\nEnter a valid API key or press Enter to quit: ", '*').trim();
-            }
-
-            if (apiKeyOrUrl.isEmpty()) {
-                System.out.println("Skipped to enable AI mode.");
-                return false;
-            }
-
-            if (!provider.isEmpty()) {
-                if (("ollama".equals(provider) || "1".equals(provider)) && apiKeyOrUrl.toLowerCase().startsWith("http")) {
-                    OLLAMA_BASE_URL = apiKeyOrUrl;
-                    findModel();
-                    modelSelection();
-                    validResponse = true;
-                } else {
-                    System.out.println("[ERROR] Enter a valid combination of provider and API key.");
+        provider = OLLAMA;
+        OLLAMA_BASE_URL = System.getProperty("ollama.base.url");
+
+        if (OLLAMA_BASE_URL == null || OLLAMA_BASE_URL.isBlank()) {
+            OLLAMA_BASE_URL = "http://localhost:11434";
+        }
+
+        model = System.getProperty("chat.model.id");
+
+        // No model configured: fall back to an installed gpt-oss model, if one is available
+        if (model == null || model.isBlank()) {
+            Ollama ollamaAPI = new Ollama(OLLAMA_BASE_URL);
+            List<Model> models;
+            try {
+                models = ollamaAPI.listModels();
+                for (Model installedModel : models) {
+                    String modelName = installedModel.getModelName();
+                    if (modelName.equals("gpt-oss")) {
+                        model = modelName;
+                        break;
+                    }
+                }
+                if (model == null || model.isBlank()) {
+                    return false;
                }
+            } catch (Exception exception) {
+                // Surface connection and lookup failures to the caller
+                throw exception;
            }
        }
        return true;
@@ -134,7 +114,6 @@ public static void findModel() {
 
     public ChatModel getChatModel() {
         if (chatModel == null) {
-            findModel();
             if (provider.equals(OLLAMA)) {
                 chatModel = OllamaChatModel.builder()
                         .baseUrl(OLLAMA_BASE_URL)
diff --git a/src/main/java/io/openliberty/tools/common/plugins/util/DevUtil.java b/src/main/java/io/openliberty/tools/common/plugins/util/DevUtil.java
index e5f8a3bf..206cf941 100644
--- a/src/main/java/io/openliberty/tools/common/plugins/util/DevUtil.java
+++ b/src/main/java/io/openliberty/tools/common/plugins/util/DevUtil.java
@@ -29,6 +29,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.net.ConnectException;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.MalformedURLException;
@@ -2501,8 +2502,6 @@ public void runHotkeyReaderThread(ThreadPoolExecutor executor) {
 
     private void printDevModeMessages(boolean inputUnavailable, boolean startup) throws PluginExecutionException {
         // the following will be printed only on startup or restart
         if (startup) {
-            getChatAgent();
-
             // print barrier header
             info(formatAttentionBarrier());
@@ -2529,19 +2528,19 @@ private void printDevModeMessages(boolean inputUnavailable, boolean startup) thr
         }
         if (startup) {
             printPortInfo(true);
-            if (getChatAgent() == null) {
-                AIMode = false;
-            } else {
+            if (isChatAgentAvailable()) {
                 AIMode = true;
+                printAIStatus();
+            } else {
+                AIMode = false;
             }
-            printAIStatus();
             // print barrier footer
             info(formatAttentionBarrier());
         }
     }
 
     private void printAIStatus() {
-        if (AIMode == false || getChatAgent() == null) {
+        if (AIMode == false || !isChatAgentAvailable()) {
             return;
         }
         info(formatAttentionMessage(""));
@@ -2775,13 +2774,43 @@ private boolean incrementGenerateFeatures() {
     private ChatAgent getChatAgent() {
         if (chatAgent == null) {
             try {
-                chatAgent = new ChatAgent(1);
+                this.chatAgent = new ChatAgent(1);
             } catch (Exception e) {
-                debug(e.getMessage());
             }
         }
         return chatAgent;
     }
+
+    private boolean isChatAgentAvailable() {
+        return chatAgent != null;
+    }
+
+    private void resetChatAgent() {
+        if (chatAgent != null) {
+            chatAgent.clearAssistant();
+            this.chatAgent = null;
+        }
+    }
+
+    private boolean isChatAgentValid() {
+        if (chatAgent == null) {
+            try {
+                // Send a short test prompt to confirm the configured model actually responds
+                int memoryIDTest = 2;
+                chatAgent = new ChatAgent(memoryIDTest);
+                String response = chatAgent.chat("Test message");
+                if (response != null && !response.isBlank()) {
+                    return true;
+                }
+            } catch (RuntimeException runtimeException) {
+                return false;
+            } catch (ConnectException connectException) {
+                return false;
+            } catch (Exception e) {
+                return false;
+            }
+        }
+        return true;
+    }
 
     private void chat(String message) {
         String response = null;
@@ -2883,29 +2912,42 @@ private void readInput() {
                     }
                 } else if (a.isPressed(line)) {
                     if (AIMode) {
-                        info("AI mode has been turned off.");
+                        if (isChatAgentAvailable()) {
+                            resetChatAgent();
+                        }
                         AIMode = false;
+                        info("AI mode has been turned off.");
                     } else {
-                        if (getChatAgent() == null) {
-                            if (!ModelBuilder.promptInputProvider()) {
-                                continue;
+                        if (!isChatAgentAvailable()) {
+                            try {
+                                getChatAgent();
+                                boolean validSetModelProvider = ModelBuilder.selectInputProvider();
+                                System.out.print("setting up...\n");
+                                boolean validConnectionToChatAgent = isChatAgentValid();
+                                if (validSetModelProvider && validConnectionToChatAgent) {
+                                    AIMode = true;
+                                    info(formatAttentionBarrier());
+                                    printAIStatus();
+                                    info(formatAttentionMessage(""));
+                                    info(formatAttentionBarrier());
+                                    continue;
+                                } else if (!validSetModelProvider) {
+                                    warn("Failed to enable AI mode. Could not find the model. Ensure the model is available, or provide a valid model through the chat.model.id system property when starting dev mode.");
+                                } else if (!validConnectionToChatAgent) {
+                                    warn("Failed to enable AI mode. Ensure Ollama is installed and the gpt-oss model is pulled. Otherwise, provide a valid Ollama URL and model through the ollama.base.url and chat.model.id system properties when starting dev mode.");
+                                }
+                            } catch (Exception exception) {
+                                warn("Failed to enable AI mode. Ensure Ollama is installed and the gpt-oss model is pulled. Otherwise, provide a valid Ollama URL and model through the ollama.base.url and chat.model.id system properties when starting dev mode.");
                            }
-                            System.out.print("\rsetting up...");
-                            if (getChatAgent() == null) {
+                            if (!isChatAgentAvailable()) {
                                 AIMode = false;
                                 ModelBuilder.cleanInputProvider();
-                                System.out.print("\r \r");
-                                System.out.println("Failed to enable AI mode.");
                                 continue;
                             }
                             System.out.print("\r \r");
                         }
-                        AIMode = true;
-                        info(formatAttentionBarrier());
-                        printAIStatus();
-                        info(formatAttentionMessage(""));
-                        info(formatAttentionBarrier());
                     }
+
                 } else if ((t.isPressed(line) && isChangeOnDemandTestsAction()) || (enter.isPressed(line) && !isChangeOnDemandTestsAction())) {
                     debug("Detected test command. Running tests... ");
                     if (isMultiModuleProject()) {
@@ -2917,7 +2959,7 @@ private void readInput() {
                 } else if (enter.isPressed(line) && isChangeOnDemandTestsAction()) {
                     warn("Unrecognized command: Enter. To see the help menu, type 'h' and press Enter.");
                 } else if (AIMode && line.startsWith("[")) {
-                    if (getChatAgent() == null) {
+                    if (!isChatAgentAvailable()) {
                         warn("AI could not be started, ensure the API/URL and model is correct");
                     }
                     if (line.trim().startsWith("[")) {
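
Reviewer note: after this change, dev mode no longer prompts for a provider. AI mode is driven by two JVM system properties, ollama.base.url (defaulting to http://localhost:11434) and chat.model.id (defaulting to an installed gpt-oss model discovered through the Ollama client). The snippet below is a minimal, standalone sketch of that resolution order for local testing only; AiConfigSketch and listInstalledModels() are hypothetical names, and the model list is stubbed rather than fetched from a running Ollama instance.

    // Hypothetical sketch of the property-resolution order introduced in selectInputProvider().
    import java.util.List;

    public class AiConfigSketch {

        // Stand-in for the Ollama client's model listing; a real run would query the Ollama endpoint.
        static List<String> listInstalledModels() {
            return List.of("llama3.2", "gpt-oss");
        }

        public static void main(String[] args) {
            String baseUrl = System.getProperty("ollama.base.url");
            if (baseUrl == null || baseUrl.isBlank()) {
                baseUrl = "http://localhost:11434"; // same default as the diff
            }

            String model = System.getProperty("chat.model.id");
            if (model == null || model.isBlank()) {
                // Fall back to an installed gpt-oss model, mirroring the new code path
                model = listInstalledModels().stream()
                        .filter("gpt-oss"::equals)
                        .findFirst()
                        .orElse(null);
            }

            if (model == null) {
                System.out.println("AI mode cannot be enabled: no model configured or discovered.");
            } else {
                System.out.println("AI mode would use " + model + " at " + baseUrl);
            }
        }
    }

For example, running java -Dollama.base.url=http://localhost:11434 -Dchat.model.id=gpt-oss AiConfigSketch.java prints the combination dev mode would pick up, while omitting both properties exercises the defaults.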