Skip to content

Commit f9c293f

Browse files
authored
cohere llm (#12)
* cohere llm * cohere example
1 parent 22d2ed5 commit f9c293f

File tree

10 files changed

+278
-0
lines changed

10 files changed

+278
-0
lines changed

llm4j-cohere/pom.xml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4+
5+
<modelVersion>4.0.0</modelVersion>
6+
7+
<parent>
8+
<groupId>org.llm4j</groupId>
9+
<artifactId>llm4j-parent</artifactId>
10+
<version>0.0-SNAPSHOT</version>
11+
<relativePath>../parent-pom.xml</relativePath>
12+
</parent>
13+
14+
<artifactId>llm4j-cohere</artifactId>
15+
16+
<packaging>jar</packaging>
17+
<name>LLM4J Cohere</name>
18+
<description>The LLM4J API implementation for Cohere</description>
19+
20+
<url>http://github.com/dzlab</url>
21+
22+
<properties>
23+
<module-name>org.llm4j.palm</module-name>
24+
<llm4j.provider.implementation>org.llm4j.cohere.CohereServiceProvider</llm4j.provider.implementation>
25+
<llm4j.provider.type>cohere</llm4j.provider.type>
26+
<cohere.version>0.1</cohere.version>
27+
</properties>
28+
29+
<dependencies>
30+
<dependency>
31+
<groupId>org.llm4j</groupId>
32+
<artifactId>llm4j-api</artifactId>
33+
</dependency>
34+
35+
<dependency>
36+
<groupId>com.github.llmjava</groupId>
37+
<artifactId>cohere4j</artifactId>
38+
<version>${cohere.version}</version>
39+
</dependency>
40+
</dependencies>
41+
42+
<repositories>
43+
<repository>
44+
<id>jitpack.io</id>
45+
<url>https://jitpack.io</url>
46+
</repository>
47+
</repositories>
48+
49+
</project>
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package org.llm4j.cohere;
2+
3+
import com.github.llmjava.cohere4j.CohereClient;
4+
import com.github.llmjava.cohere4j.CohereConfig;
5+
import com.github.llmjava.cohere4j.request.EmbedRequest;
6+
import com.github.llmjava.cohere4j.request.GenerateRequest;
7+
import com.github.llmjava.cohere4j.response.EmbedResponse;
8+
import com.github.llmjava.cohere4j.response.GenerateResponse;
9+
import org.apache.commons.configuration2.Configuration;
10+
import org.llm4j.api.ChatHistory;
11+
import org.llm4j.api.LanguageModel;
12+
import org.llm4j.api.LanguageModelFactory;
13+
14+
import java.util.ArrayList;
15+
import java.util.Arrays;
16+
import java.util.List;
17+
import java.util.Map;
18+
19+
public class CohereLanguageModel implements LanguageModel {
20+
private CohereConfig config;
21+
22+
private CohereClient client;
23+
CohereLanguageModel(Builder builder) {
24+
this.client = builder.client;
25+
this.config = builder.config;
26+
}
27+
28+
@Override
29+
public String process(String text) {
30+
GenerateRequest request = new GenerateRequest.Builder()
31+
.withConfig(config)
32+
.withPrompt(text)
33+
.build();
34+
GenerateResponse response = client.generate(request);
35+
return response.getTexts().get(0);
36+
}
37+
38+
@Override
39+
public String process(ChatHistory history) {
40+
List<String> lines = new ArrayList<>();
41+
// Add context
42+
if(history.getContext()!=null) lines.add(history.getContext());
43+
// Add examples
44+
for(Map.Entry<ChatHistory.Message, ChatHistory.Message> pair: history.getExampleList()) {
45+
lines.add(pair.getKey().toString());
46+
lines.add(pair.getValue().toString());
47+
}
48+
// Add conversations
49+
for(ChatHistory.Message message: history.getMessageList()) {
50+
lines.add(message.toString());
51+
}
52+
// submit
53+
String text = String.join("\n", lines);
54+
55+
return process(text);
56+
}
57+
58+
@Override
59+
public List<Float> embed(String text) {
60+
EmbedRequest request = new EmbedRequest.Builder()
61+
.withText(text)
62+
.withModel(config.getModel())
63+
.build();
64+
EmbedResponse response = client.embed(request);
65+
List<Float> embeddings = Arrays.asList(response.getEmbeddings(0));
66+
return embeddings;
67+
}
68+
69+
70+
public static final class Builder implements LanguageModelFactory {
71+
private CohereConfig config;
72+
private CohereClient client;
73+
74+
public LanguageModel getLanguageModel(Configuration config) {
75+
this.config = new CohereConfig(config);
76+
this.client = new CohereClient.Builder().withConfig(this.config).build();
77+
return new CohereLanguageModel(this);
78+
}
79+
}
80+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.llm4j.cohere;
2+
3+
import org.llm4j.api.LanguageModelFactory;
4+
import org.llm4j.spi.LLM4JServiceProvider;
5+
6+
public class CohereServiceProvider implements LLM4JServiceProvider {
7+
8+
public LanguageModelFactory getLLMFactory() {
9+
return new CohereLanguageModel.Builder();
10+
}
11+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
org.llm4j.cohere.CohereServiceProvider
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package org.llm4j.cohere;
2+
3+
import org.apache.commons.configuration2.Configuration;
4+
import org.apache.commons.configuration2.EnvironmentConfiguration;
5+
import org.apache.commons.configuration2.MapConfiguration;
6+
import org.apache.commons.configuration2.builder.fluent.Configurations;
7+
import org.apache.commons.configuration2.ex.ConfigurationException;
8+
import org.junit.jupiter.api.Disabled;
9+
import org.junit.jupiter.api.DisplayName;
10+
import org.junit.jupiter.api.Test;
11+
import org.junit.Ignore;
12+
import org.llm4j.api.ChatHistory;
13+
import org.llm4j.api.LanguageModel;
14+
15+
import java.util.HashMap;
16+
import java.util.Map;
17+
import java.util.List;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static com.google.common.truth.Truth.assertWithMessage;
21+
22+
@Disabled("requires API key to run")
23+
public class CohereLanguageModelTest {
24+
25+
@Test
26+
@DisplayName("Should generate text")
27+
public void should_process_text_generation_request() throws ConfigurationException {
28+
29+
Configuration config = new Configurations().properties("llm4j.properties");
30+
31+
LanguageModel llm = new CohereLanguageModel.Builder()
32+
.getLanguageModel(config);
33+
34+
String answer = llm.process("In what country is El Outed located?");
35+
36+
assertWithMessage("Answer should contain right answer").
37+
that(answer.toLowerCase()).contains("algeria");
38+
}
39+
40+
@Test
41+
@DisplayName("Should process chat request")
42+
public void should_process_chat_request() throws ConfigurationException {
43+
44+
Configuration config = new Configurations().properties("llm4j.properties");
45+
46+
LanguageModel llm = new CohereLanguageModel.Builder()
47+
.getLanguageModel(config);
48+
49+
ChatHistory history = new ChatHistory()
50+
.setContext("Respond to all questions with a rhyming poem.")
51+
.addExample(
52+
"What is the capital of Algeria?",
53+
"If the capital of Algeria is what you seek, Algiers is where you ought to peek.")
54+
.addMessage("How tall is Makam Echahid?");
55+
56+
String answer = llm.process(history);
57+
58+
assertWithMessage("Answer should contain right answer").
59+
that(answer.toLowerCase()).contains("algeria");
60+
}
61+
62+
63+
@Test
64+
@DisplayName("Should embed text")
65+
public void should_process_text_embed_request() throws ConfigurationException {
66+
Map<String, String> properties = new HashMap<>();
67+
properties.put("cohere.apiKey", "${env:COHERE_API_KEY}");
68+
properties.put("cohere.model", "embed-english-light-v2.0");
69+
MapConfiguration config = new MapConfiguration(properties);
70+
71+
LanguageModel llm = new CohereLanguageModel.Builder()
72+
.getLanguageModel(config);
73+
74+
List<Float> embeddings = llm.embed("In what country is El Outed located?");
75+
76+
assertThat(embeddings).isNotEmpty();
77+
}
78+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Set API key using env variable or use actual value
2+
cohere.apiKey=${env:COHERE_API_KEY}
3+
4+
#cohere.model=command
5+
6+
# timeout in milliseconds
7+
timeout=10000
8+
9+
topK=3
10+
topP=0.4
11+
temperature=0.7
12+
maxNewTokens=256
13+
maxOutputTokens=1024
14+
candidateCount=1

llm4j-examples/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
<groupId>org.llm4j</groupId>
3838
<artifactId>llm4j-palm</artifactId>
3939
</dependency>
40+
<dependency>
41+
<groupId>org.llm4j</groupId>
42+
<artifactId>llm4j-cohere</artifactId>
43+
</dependency>
4044

4145
</dependencies>
4246

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.llm4j.examples;
2+
3+
import org.apache.commons.configuration2.Configuration;
4+
import org.apache.commons.configuration2.builder.fluent.Configurations;
5+
import org.apache.commons.configuration2.ex.ConfigurationException;
6+
import org.llm4j.api.LLM4J;
7+
import org.llm4j.cohere.CohereLanguageModel;
8+
9+
public class CohereApp {
10+
11+
public static void main(String[] args) throws ConfigurationException {
12+
Configuration config = new Configurations().properties("cohere.properties");
13+
14+
CohereLanguageModel.Builder factory = new CohereLanguageModel.Builder();
15+
CohereLanguageModel llm = (CohereLanguageModel) LLM4J.getLanguageModel(config, factory);
16+
17+
String answer = llm.process("In what country is Andalossia?");
18+
System.out.println(answer);
19+
}
20+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Set API key using env variable or use actual value
2+
cohere.apiKey=${env:COHERE_API_KEY}
3+
4+
cohere.model=command
5+
6+
# timeout in milliseconds
7+
timeout=10000
8+
9+
topK=3
10+
topP=0.4
11+
temperature=0.7
12+
maxNewTokens=256
13+
maxOutputTokens=1024
14+
candidateCount=1

pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
<modules>
1616
<module>parent-pom.xml</module>
1717
<module>llm4j-api</module>
18+
<module>llm4j-cohere</module>
1819
<module>llm4j-examples</module>
1920
<module>llm4j-huggingface</module>
2021
<module>llm4j-palm</module>
@@ -30,6 +31,12 @@
3031
<version>${project.version}</version>
3132
</dependency>
3233

34+
<dependency>
35+
<groupId>org.llm4j</groupId>
36+
<artifactId>llm4j-cohere</artifactId>
37+
<version>${project.version}</version>
38+
</dependency>
39+
3340
<dependency>
3441
<groupId>org.llm4j</groupId>
3542
<artifactId>llm4j-huggingface</artifactId>

0 commit comments

Comments
 (0)