Skip to content

Commit 66afa2c

Browse files
authored
Merge pull request #5 from Barqawiz/support-multi-output
Support multi output
2 parents 51ffd9e + d56078b commit 66afa2c

File tree

15 files changed

+176
-33
lines changed

15 files changed

+176
-33
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Intelligent Java
2-
*IntelliJava V0.6.0*
2+
*IntelliJava V0.6.1*
33

44
Intelligent java (IntelliJava) is the ultimate tool for Java developers looking to integrate with the latest language models and deep learning frameworks. The library provides a simple and intuitive API with convenient methods for sending text input to models like GPT-3 and DALL·E, and receiving generated text or images in return. With just a few lines of code, you can easily access the power of cutting-edge AI models to enhance your projects.
55

@@ -14,32 +14,32 @@ The supported models:
1414
3. Call the ``RemoteLanguageModel`` for the language models and ``RemoateImageModel`` for image generation.
1515

1616
## Integration
17-
The package released to [Maven Central Repository](https://central.sonatype.dev/artifact/io.github.barqawiz/intellijava.core/0.6.0).
17+
The package released to [Maven Central Repository](https://central.sonatype.com/artifact/io.github.barqawiz/intellijava.core/0.6.2).
1818

1919
Maven:
2020
```xml
2121
<dependency>
2222
<groupId>io.github.barqawiz</groupId>
2323
<artifactId>intellijava.core</artifactId>
24-
<version>0.6.0</version>
24+
<version>0.6.2</version>
2525
</dependency>
2626
```
2727

2828
Gradle:
2929

3030
```
31-
implementation group: 'io.github.barqawiz', name: 'intellijava.core', version: '0.6.0'
31+
implementation 'io.github.barqawiz:intellijava.core:0.6.2'
3232
```
3333

3434
Gradle(Kotlin):
3535
```
36-
implementation("io.github.barqawiz:intellijava.core:0.6.0")
36+
implementation("io.github.barqawiz:intellijava.core:0.6.2")
3737
```
3838

3939
Jar download:
40-
[intellijava.jar](https://repo1.maven.org/maven2/io/github/barqawiz/intellijava.core/0.6.0/intellijava.core-0.6.0.jar).
40+
[intellijava.jar](https://repo1.maven.org/maven2/io/github/barqawiz/intellijava.core/0.6.2/intellijava.core-0.6.2.jar).
4141

42-
For ready integration: try the sample_code.
42+
For ready integration: try the [sample_code](https://github.com/Barqawiz/IntelliJava/tree/main/sample_code).
4343

4444
## Code Example
4545
**Language model code** (2 steps):
@@ -72,6 +72,7 @@ For full example check the code inside sample_code project.
7272

7373
## Third-party dependencies
7474
The only dependencies is **GSON**.
75+
*Required to add manually when using IntelliJava jar. However, if you imported this repo through Maven, it will handle the dependencies.*
7576

7677
For Maven:
7778
```

core/com.intellijava.core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>io.github.barqawiz</groupId>
88
<artifactId>intellijava.core</artifactId>
9-
<version>0.6.0</version>
9+
<version>0.6.2</version>
1010

1111
<name>Intellijava</name>
1212
<description>IntelliJava allows java developers to easily integrate with the latest language models, image generation, and deep learning frameworks.</description>

core/com.intellijava.core/src/main/java/com/intellijava/core/controller/RemoteLanguageModel.java

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import com.intellijava.core.model.CohereLanguageResponse;
24+
import com.intellijava.core.model.CohereLanguageResponse.Generation;
2425
import com.intellijava.core.model.OpenaiLanguageResponse;
26+
import com.intellijava.core.model.OpenaiLanguageResponse.Choice;
2527
import com.intellijava.core.model.SupportedLangModels;
28+
import com.intellijava.core.model.OpenaiImageResponse.Data;
2629
import com.intellijava.core.model.input.LanguageModelInput;
2730
import com.intellijava.core.wrappers.CohereAIWrapper;
2831
import com.intellijava.core.wrappers.OpenAIWrapper;
@@ -131,8 +134,10 @@ private void initiate(String keyValue, SupportedLangModels keyType) {
131134
*
132135
* Call a remote large model to generate any text based on the received prompt.
133136
*
134-
* @param langInput flexible builder for language model parameters.
137+
* To support multiple response call the variation function generateMultiText.
135138
*
139+
* @param langInput flexible builder for language model parameters.
140+
*
136141
* @return string for the model response.
137142
* @throws IOException if there is an error when connecting to the
138143
* OpenAI API.
@@ -143,11 +148,42 @@ private void initiate(String keyValue, SupportedLangModels keyType) {
143148
public String generateText(LanguageModelInput langInput) throws IOException {
144149

145150
if (this.keyType.equals(SupportedLangModels.openai)) {
146-
return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
147-
langInput.getMaxTokens());
151+
return this.generateOpenaiText(langInput.getModel(),
152+
langInput.getPrompt(), langInput.getTemperature(),
153+
langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
154+
} else if (this.keyType.equals(SupportedLangModels.cohere)) {
155+
return this.generateCohereText(langInput.getModel(),
156+
langInput.getPrompt(), langInput.getTemperature(),
157+
langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
158+
} else {
159+
throw new IllegalArgumentException("This version support openai keyType only");
160+
}
161+
162+
}
163+
164+
/**
165+
*
166+
* Call a remote large model to generate any text based on the received prompt.
167+
*
168+
* @param langInput flexible builder for language model parameters.
169+
*
170+
* @return list of model responses.
171+
* @throws IOException if there is an error when connecting to the
172+
* OpenAI API.
173+
* @throws IllegalArgumentException if the keyType passed in the constructor is
174+
* not "openai".
175+
*
176+
*/
177+
public List<String> generateMultiText(LanguageModelInput langInput) throws IOException {
178+
179+
if (this.keyType.equals(SupportedLangModels.openai)) {
180+
return this.generateOpenaiText(langInput.getModel(),
181+
langInput.getPrompt(), langInput.getTemperature(),
182+
langInput.getMaxTokens(), langInput.getNumberOfOutputs());
148183
} else if (this.keyType.equals(SupportedLangModels.cohere)) {
149-
return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
150-
langInput.getMaxTokens());
184+
return this.generateCohereText(langInput.getModel(),
185+
langInput.getPrompt(), langInput.getTemperature(),
186+
langInput.getMaxTokens(), langInput.getNumberOfOutputs());
151187
} else {
152188
throw new IllegalArgumentException("This version support openai keyType only");
153189
}
@@ -163,11 +199,13 @@ public String generateText(LanguageModelInput langInput) throws IOException {
163199
* @param prompt text of the required action or the question.
164200
* @param temperature higher values means more risks and creativity.
165201
* @param maxTokens maximum size of the model input and output.
202+
* @param numberOfOutputs number of model outputs.
166203
* @return string model response.
167204
* @throws IOException if there is an error when connecting to the OpenAI API.
168205
*
169206
*/
170-
private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens)
207+
private List<String> generateOpenaiText(String model, String prompt, float temperature,
208+
int maxTokens, int numberOfOutputs)
171209
throws IOException {
172210

173211
if (model.equals(""))
@@ -178,10 +216,16 @@ private String generateOpenaiText(String model, String prompt, float temperature
178216
params.put("prompt", prompt);
179217
params.put("temperature", temperature);
180218
params.put("max_tokens", maxTokens);
219+
params.put("n", numberOfOutputs);
181220

182221
OpenaiLanguageResponse resModel = (OpenaiLanguageResponse) openaiWrapper.generateText(params);
183222

184-
return resModel.getChoices().get(0).getText();
223+
List<String> outputs = new ArrayList<>();
224+
for (Choice item : resModel.getChoices()) {
225+
outputs.add(item.getText());
226+
}
227+
228+
return outputs;
185229

186230
}
187231

@@ -192,11 +236,13 @@ private String generateOpenaiText(String model, String prompt, float temperature
192236
* @param prompt text of the required action or the question.
193237
* @param temperature higher values means more risks and creativity.
194238
* @param maxTokens maximum size of the model input and output.
239+
* @param numberOfOutputs number of model outputs.
195240
* @return string model response.
196241
* @throws IOException if there is an error when connecting to the API.
197242
*
198243
*/
199-
private String generateCohereText(String model, String prompt, float temperature, int maxTokens)
244+
private List<String> generateCohereText(String model, String prompt, float temperature,
245+
int maxTokens, int numberOfOutputs)
200246
throws IOException {
201247

202248
if (model.equals(""))
@@ -207,10 +253,16 @@ private String generateCohereText(String model, String prompt, float temperature
207253
params.put("prompt", prompt);
208254
params.put("temperature", temperature);
209255
params.put("max_tokens", maxTokens);
256+
params.put("num_generations", numberOfOutputs);
210257

211258
CohereLanguageResponse resModel = (CohereLanguageResponse) cohereWrapper.generateText(params);
212-
213-
return resModel.getGenerations().get(0).getText();
259+
260+
List<String> outputs = new ArrayList<>();
261+
for (Generation item: resModel.getGenerations()) {
262+
outputs.add(item.getText());
263+
}
264+
265+
return outputs;
214266

215267
}
216268
}

core/com.intellijava.core/src/main/java/com/intellijava/core/model/CohereLanguageResponse.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ public class CohereLanguageResponse extends BaseRemoteModel{
1515
private List<Generation> generations;
1616
private String prompt;
1717

18+
/**
19+
* CohereLanguageResponse default constructor.
20+
*/
21+
public CohereLanguageResponse() {
22+
23+
}
24+
1825
/**
1926
*
2027
* Generation is wrapper for the response
@@ -26,6 +33,13 @@ public static class Generation {
2633
private String id;
2734
private String text;
2835

36+
/**
37+
* Generation default constructor.
38+
*/
39+
public Generation() {
40+
41+
}
42+
2943
/**
3044
* Get the unique identifier for the generation.
3145
*

core/com.intellijava.core/src/main/java/com/intellijava/core/model/input/LanguageModelInput.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public class LanguageModelInput {
1717
private String prompt;
1818
private float temperature;
1919
private int maxTokens;
20+
private int numberOfOutputs = 1;
2021

2122
/**
2223
* Private Constructor for the Builder.
@@ -27,6 +28,7 @@ private LanguageModelInput(Builder builder) {
2728
this.prompt = builder.prompt;
2829
this.temperature = builder.temperature;
2930
this.maxTokens = builder.maxTokens;
31+
this.numberOfOutputs = builder.numberOfOutputs;
3032
}
3133
/**
3234
*
@@ -38,6 +40,7 @@ public static class Builder {
3840
private String prompt;
3941
private float temperature;
4042
private int maxTokens;
43+
private int numberOfOutputs = 1;
4144

4245
/**
4346
* Language input Constructor.
@@ -90,6 +93,22 @@ public Builder setMaxTokens(int maxTokens) {
9093
this.maxTokens = maxTokens;
9194
return this;
9295
}
96+
97+
/**
98+
* Setter for numberOfOutputs
99+
* @param numberOfOutputs number of model outputs, default value is 1.
100+
*
101+
* Cohere maximum value is five.
102+
*
103+
* @return instance of Builder
104+
*/
105+
public Builder setNumberOfOutputs(int numberOfOutputs) {
106+
if (this.numberOfOutputs < 0)
107+
this.numberOfOutputs = 0;
108+
109+
this.numberOfOutputs = numberOfOutputs;
110+
return this;
111+
}
93112

94113
/**
95114
* Build the final LanguageModelInput object.
@@ -130,5 +149,15 @@ public float getTemperature() {
130149
public int getMaxTokens() {
131150
return maxTokens;
132151
}
152+
153+
/**
154+
* Getter for number of model outputs.
155+
* @return numberOfOutputs
156+
*/
157+
public int getNumberOfOutputs() {
158+
return numberOfOutputs;
159+
}
160+
161+
133162
}
134163

core/com.intellijava.core/src/main/java/com/intellijava/core/utils/ConnHelper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,5 @@ public static String readStream(InputStream stream) throws IOException {
108108
}
109109
}
110110
return result.toString();
111-
}
111+
}
112112
}

core/com.intellijava.core/src/main/java/com/intellijava/core/wrappers/CohereAIWrapper.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
/**
1717
*
18+
* Wrapper for the Cohere API models.
1819
*
1920
* @author github.com/Barqawiz
2021
*

core/com.intellijava.core/src/main/java/com/intellijava/core/wrappers/OpenAIWrapper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package com.intellijava.core.wrappers;
1717

18+
import java.io.BufferedReader;
1819
import java.io.IOException;
1920
import java.io.InputStreamReader;
2021
import java.io.OutputStream;
@@ -70,7 +71,7 @@ public BaseRemoteModel generateText(Map<String, Object> params) throws IOExcepti
7071
String url = API_BASE_URL + Config2.getInstance().getProperty("url.openai.completions");
7172

7273
String json = ConnHelper.convertMaptToJson(params);
73-
74+
7475
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
7576
connection.setRequestMethod("POST");
7677
connection.setRequestProperty("Content-Type", "application/json");

core/com.intellijava.core/src/test/java/com/intellijava/core/CohereModelConnectionTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
import com.intellijava.core.utils.Config2;
1616
import com.intellijava.core.wrappers.CohereAIWrapper;
1717

18+
/**
19+
*
20+
* Unit test for Remote Language Model
21+
*/
1822
public class CohereModelConnectionTest {
1923

2024
/**

core/com.intellijava.core/src/test/java/com/intellijava/core/OpenaiModelConnectionTest.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.intellijava.core.controller.RemoteLanguageModel;
2828
import com.intellijava.core.model.OpenaiImageResponse;
2929
import com.intellijava.core.model.OpenaiImageResponse.Data;
30+
import com.intellijava.core.model.SupportedLangModels;
3031
import com.intellijava.core.model.input.ImageModelInput;
3132
import com.intellijava.core.model.input.LanguageModelInput;
3233
import com.intellijava.core.utils.Config2;
@@ -51,10 +52,10 @@ public void testOpenaiCompletionRemoteModel() {
5152

5253
try {
5354

54-
RemoteLanguageModel wrapper = new RemoteLanguageModel(openaiKey, "openai");
55+
RemoteLanguageModel wrapper = new RemoteLanguageModel(openaiKey, SupportedLangModels.openai);
5556

5657
LanguageModelInput input = new LanguageModelInput.Builder("return a java code that print hello world")
57-
.setModel("text-davinci-002").setTemperature(0.7f).setMaxTokens(50).build();
58+
.setModel("text-davinci-003").setTemperature(0.7f).setMaxTokens(50).build();
5859

5960
if (openaiKey.isBlank()) return;
6061

@@ -75,6 +76,37 @@ public void testOpenaiCompletionRemoteModel() {
7576
}
7677
}
7778

79+
80+
@Test
81+
public void testOpenaiMultiTextCompletionRemoteModel() {
82+
83+
try {
84+
85+
RemoteLanguageModel wrapper = new RemoteLanguageModel(openaiKey, "openai");
86+
87+
LanguageModelInput input = new LanguageModelInput.Builder("Summarize the plot of the 'Inception' movie in two sentences")
88+
.setModel("text-davinci-003").setTemperature(0.7f)
89+
.setMaxTokens(80).setNumberOfOutputs(2).build();
90+
91+
if (openaiKey.isBlank()) return;
92+
93+
List<String> resValues = wrapper.generateMultiText(input);
94+
95+
for (String result : resValues)
96+
System.out.print("- " + result);
97+
98+
assert resValues.size() == 2;
99+
100+
} catch (IOException e) {
101+
if (openaiKey.isBlank()) {
102+
System.out.print("testOpenaiCompletion: set the API key to run the test case.");
103+
} else {
104+
fail("Test case failed with exception: " + e.getMessage());
105+
}
106+
107+
}
108+
}
109+
78110
@Test
79111
public void testImageWrapper() {
80112

0 commit comments

Comments
 (0)