|
16 | 16 | package com.intellijava.core.controller; |
17 | 17 |
|
18 | 18 | import java.io.IOException; |
| 19 | +import java.util.ArrayList; |
19 | 20 | import java.util.HashMap; |
| 21 | +import java.util.List; |
20 | 22 | import java.util.Map; |
21 | | - |
| 23 | +import com.intellijava.core.model.CohereLanguageResponse; |
22 | 24 | import com.intellijava.core.model.OpenaiLanguageResponse; |
| 25 | +import com.intellijava.core.model.SupportedLangModels; |
23 | 26 | import com.intellijava.core.model.input.LanguageModelInput; |
| 27 | +import com.intellijava.core.wrappers.CohereAIWrapper; |
24 | 28 | import com.intellijava.core.wrappers.OpenAIWrapper; |
25 | 29 |
|
26 | 30 | /** |
27 | | - * A class to call the most sophisticated remote language models. |
28 | | - * |
29 | | - * This class provides an API for interacting with OpenAI's GPT-3 language model. It is designed to be easily extensible |
30 | | - * to support other models in the future. |
| 31 | + * RemoteLanguageModel class to call the most sophisticated remote language |
| 32 | + * models. |
| 33 | + * |
| 34 | + * This class support: - Openai: - url: openai.com - description: provides an |
| 35 | + * API for interacting with OpenAI's GPT-3 language model. - model names : |
| 36 | + * text-davinci-003, text-curie-001, text-babbage-001, more. |
| 37 | + * |
| 38 | + * - cohere: - url: cohere.ai - description: provides an API for interacting |
| 39 | + * with generate language model. - it is recommended to fine tune your model or |
| 40 | + * add example of the response in the prompt when calling cohere models. - model |
| 41 | + * names : medium or xlarge |
31 | 42 | * |
32 | 43 | * @author github.com/Barqawiz |
33 | 44 | * |
34 | 45 | */ |
35 | 46 | public class RemoteLanguageModel { |
36 | | - |
37 | | - private String keyType; |
| 47 | + |
| 48 | + private SupportedLangModels keyType; |
38 | 49 | private OpenAIWrapper openaiWrapper; |
39 | | - |
| 50 | + private CohereAIWrapper cohereWrapper; |
| 51 | + |
40 | 52 | /** |
41 | | - * Constructor for the RemoteLanguageModel class. |
42 | | - * |
43 | | - * Creates an instance of the class and sets up the API key and the key type. |
44 | | - * Currently, only the "openai" key type is supported. |
45 | | - * |
46 | | - * @param keyValue the API key. |
47 | | - * @param keyType support openai only. |
48 | | - * |
49 | | - * @throws IllegalArgumentException if the keyType passed is not "openai". |
50 | | - * |
51 | | - */ |
52 | | - public RemoteLanguageModel(String keyValue, String keyType) { |
53 | | - |
54 | | - if (keyType.isEmpty() || keyType.equals("openai")) { |
55 | | - this.keyType = "openai"; |
56 | | - openaiWrapper = new OpenAIWrapper(keyValue); |
| 53 | + * Constructor for the RemoteLanguageModel class. |
| 54 | + * |
| 55 | + * Creates an instance of the class and sets up the key and the API type. |
| 56 | + * |
| 57 | + * @param keyValue the API key. |
| 58 | + * @param keyTypeString either openai (default) or cohere or send empty string |
| 59 | + * for default value. |
| 60 | + * |
| 61 | + * @throws IllegalArgumentException if the keyType passed is not "openai". |
| 62 | + * |
| 63 | + */ |
| 64 | + public RemoteLanguageModel(String keyValue, String keyTypeString) { |
| 65 | + |
| 66 | + if (keyTypeString.isEmpty()) { |
| 67 | + keyTypeString = SupportedLangModels.openai.toString(); |
| 68 | + } |
| 69 | + |
| 70 | + List<String> supportedModels = this.getSupportedModels(); |
| 71 | + |
| 72 | + if (supportedModels.contains(keyTypeString)) { |
| 73 | + this.initiate(keyValue, SupportedLangModels.valueOf(keyTypeString)); |
57 | 74 | } else { |
58 | | - throw new IllegalArgumentException("This version support openai keyType only"); |
| 75 | + String models = String.join(" - ", supportedModels); |
| 76 | + throw new IllegalArgumentException("The received keyValue not supported. Send any model from: " + models); |
59 | 77 | } |
60 | 78 | } |
61 | | - |
| 79 | + |
| 80 | + /** |
| 81 | + * Constructor for the RemoteLanguageModel class. |
| 82 | + * |
| 83 | + * Creates an instance of the class and sets up the API key and the enum key |
| 84 | + * type. |
| 85 | + * |
| 86 | + * @param keyValue the API key. |
| 87 | + * @param keyType enum version from the key type (SupportedModels). |
| 88 | + * |
| 89 | + * @throws IllegalArgumentException if the keyType passed is not "openai". |
| 90 | + * |
| 91 | + */ |
| 92 | + public RemoteLanguageModel(String keyValue, SupportedLangModels keyType) { |
| 93 | + this.initiate(keyValue, keyType); |
| 94 | + } |
| 95 | + |
| 96 | + /** |
| 97 | + * Get the supported models names as array of string |
| 98 | + * |
| 99 | + * @return supportedModels |
| 100 | + */ |
| 101 | + public List<String> getSupportedModels() { |
| 102 | + SupportedLangModels[] values = SupportedLangModels.values(); |
| 103 | + List<String> enumValues = new ArrayList<>(); |
| 104 | + |
| 105 | + for (int i = 0; i < values.length; i++) { |
| 106 | + enumValues.add(values[i].name()); |
| 107 | + } |
| 108 | + |
| 109 | + return enumValues; |
| 110 | + } |
62 | 111 |
|
| 112 | + /** |
| 113 | + * Common function to initiate the class from any constructor. |
| 114 | + * |
| 115 | + * @param keyValue the API key. |
| 116 | + * @param keyType enum version from the key type (SupportedModels). |
| 117 | + */ |
| 118 | + private void initiate(String keyValue, SupportedLangModels keyType) { |
| 119 | + // set the model type |
| 120 | + this.keyType = keyType; |
| 121 | + |
| 122 | + // generate the related model |
| 123 | + if (keyType.equals(SupportedLangModels.openai)) { |
| 124 | + this.openaiWrapper = new OpenAIWrapper(keyValue); |
| 125 | + } else if (keyType.equals(SupportedLangModels.cohere)) { |
| 126 | + this.cohereWrapper = new CohereAIWrapper(keyValue); |
| 127 | + } |
| 128 | + } |
| 129 | + |
63 | 130 | /** |
64 | 131 | * |
65 | 132 | * Call a remote large model to generate any text based on the received prompt. |
66 | 133 | * |
67 | 134 | * @param langInput flexible builder for language model parameters. |
| 135 | + * |
68 | 136 | * @return string for the model response. |
69 | | - * @throws IOException if there is an error when connecting to the OpenAI API. |
70 | | - * @throws IllegalArgumentException if the keyType passed in the constructor is not "openai". |
| 137 | + * @throws IOException if there is an error when connecting to the |
| 138 | + * OpenAI API. |
| 139 | + * @throws IllegalArgumentException if the keyType passed in the constructor is |
| 140 | + * not "openai". |
71 | 141 | * |
72 | 142 | */ |
73 | | - public String generateText(LanguageModelInput langInput) throws IOException { |
74 | | - |
75 | | - if (this.keyType.equals("openai")) { |
76 | | - return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), |
77 | | - langInput.getTemperature(), langInput.getMaxTokens()); |
| 143 | + public String generateText(LanguageModelInput langInput) throws IOException { |
| 144 | + |
| 145 | + if (this.keyType.equals(SupportedLangModels.openai)) { |
| 146 | + return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), |
| 147 | + langInput.getMaxTokens()); |
| 148 | + } else if (this.keyType.equals(SupportedLangModels.cohere)) { |
| 149 | + return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(), |
| 150 | + langInput.getMaxTokens()); |
78 | 151 | } else { |
79 | 152 | throw new IllegalArgumentException("This version support openai keyType only"); |
80 | 153 | } |
81 | | - |
| 154 | + |
82 | 155 | } |
83 | 156 |
|
84 | 157 | /** |
85 | | - * Private helper method for generating text from OpenAI GPT-3 model. |
86 | | - * |
87 | | - * @param model the model name, example: text-davinci-002. For more details about GPT-3 models, see: https://beta.openai.com/docs/models/gpt-3 |
88 | | - * @param prompt text of the required action or the question. |
89 | | - * @param temperature higher values means more risks and creativity. |
90 | | - * @param maxTokens maximum size of the model input and output. |
91 | | - * @return string model response. |
92 | | - * @throws IOException if there is an error when connecting to the OpenAI API. |
93 | | - * |
94 | | - */ |
95 | | - private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens) throws IOException { |
96 | | - |
| 158 | + * Private helper method for generating text from OpenAI GPT-3 model. |
| 159 | + * |
| 160 | + * @param model the model name, example: text-davinci-003. For more |
| 161 | + * details about GPT-3 models, see: |
| 162 | + * https://beta.openai.com/docs/models/gpt-3 |
| 163 | + * @param prompt text of the required action or the question. |
| 164 | + * @param temperature higher values means more risks and creativity. |
| 165 | + * @param maxTokens maximum size of the model input and output. |
| 166 | + * @return string model response. |
| 167 | + * @throws IOException if there is an error when connecting to the OpenAI API. |
| 168 | + * |
| 169 | + */ |
| 170 | + private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens) |
| 171 | + throws IOException { |
| 172 | + |
| 173 | + if (model.equals("")) |
| 174 | + model = "text-davinci-003"; |
| 175 | + |
97 | 176 | Map<String, Object> params = new HashMap<>(); |
98 | | - params.put("model", model); |
99 | | - params.put("prompt", prompt); |
100 | | - params.put("temperature", temperature); |
101 | | - params.put("max_tokens", maxTokens); |
102 | | - |
| 177 | + params.put("model", model); |
| 178 | + params.put("prompt", prompt); |
| 179 | + params.put("temperature", temperature); |
| 180 | + params.put("max_tokens", maxTokens); |
| 181 | + |
103 | 182 | OpenaiLanguageResponse resModel = (OpenaiLanguageResponse) openaiWrapper.generateText(params); |
104 | | - |
| 183 | + |
105 | 184 | return resModel.getChoices().get(0).getText(); |
106 | | - |
| 185 | + |
| 186 | + } |
| 187 | + |
| 188 | + /** |
| 189 | + * Private helper method for generating text from Cohere model. |
| 190 | + * |
| 191 | + * @param model the model name, either medium or xlarge. |
| 192 | + * @param prompt text of the required action or the question. |
| 193 | + * @param temperature higher values means more risks and creativity. |
| 194 | + * @param maxTokens maximum size of the model input and output. |
| 195 | + * @return string model response. |
| 196 | + * @throws IOException if there is an error when connecting to the API. |
| 197 | + * |
| 198 | + */ |
| 199 | + private String generateCohereText(String model, String prompt, float temperature, int maxTokens) |
| 200 | + throws IOException { |
| 201 | + |
| 202 | + if (model.equals("")) |
| 203 | + model = "xlarge"; |
| 204 | + |
| 205 | + Map<String, Object> params = new HashMap<>(); |
| 206 | + params.put("model", model); |
| 207 | + params.put("prompt", prompt); |
| 208 | + params.put("temperature", temperature); |
| 209 | + params.put("max_tokens", maxTokens); |
| 210 | + |
| 211 | + CohereLanguageResponse resModel = (CohereLanguageResponse) cohereWrapper.generateText(params); |
| 212 | + |
| 213 | + return resModel.getGenerations().get(0).getText(); |
| 214 | + |
107 | 215 | } |
108 | 216 | } |
0 commit comments