Skip to content

Commit d08b08d

Browse files
authored
wrap base config in PaLMConfig (#15)
1 parent 2df85ba commit d08b08d

File tree

6 files changed

+63
-21
lines changed

6 files changed

+63
-21
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,51 @@
11
package org.llm4j.palm;
22

3+
import org.apache.commons.configuration2.Configuration;
4+
35
public class PaLMConfig {
46
public static String MODEL_ID_KEY = "palm.modelId";
7+
public static String CHAT_MODEL_ID = "models/chat-bison-001";
8+
public static String EMBED_MODEL_ID = "models/embedding-gecko-001";
9+
public static String TEXT_MODEL_ID = "models/text-bison-001";
10+
11+
private final Configuration configs;
12+
public PaLMConfig(Configuration configs) {
13+
this.configs = configs;
14+
}
15+
16+
public String getChatModelId() {
17+
return configs.getString(PaLMConfig.MODEL_ID_KEY, CHAT_MODEL_ID);
18+
}
19+
20+
public String getEmbedModelId() {
21+
return configs.getString(PaLMConfig.MODEL_ID_KEY, EMBED_MODEL_ID);
22+
}
23+
24+
public String getTextModelId() {
25+
return configs.getString(PaLMConfig.MODEL_ID_KEY, TEXT_MODEL_ID);
26+
}
27+
28+
public Integer getTopK() {
29+
return configs.getInteger("topK", 1);
30+
}
31+
32+
public Double getTopP() {
33+
return configs.getDouble("topP", 0.9);
34+
}
35+
36+
public Double getTemperature() {
37+
return configs.getDouble("temperature", 1.0);
38+
}
39+
40+
public Integer getMaxNewTokens() {
41+
return configs.getInteger("maxNewTokens", 10);
42+
}
43+
44+
public Integer getMaxOutputTokens() {
45+
return configs.getInteger("maxOutputTokens", 1000);
46+
}
47+
48+
public Integer getCandidateCount() {
49+
return configs.getInteger("candidateCount", 1);
50+
}
551
}

llm4j-palm/src/main/java/org/llm4j/palm/PaLMLanguageModel.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
public class PaLMLanguageModel implements LanguageModel {
1616

17-
private Configuration config;
17+
private PaLMConfig config;
1818
private PaLMClient client;
1919

2020
public PaLMLanguageModel(Builder builder) {
@@ -64,11 +64,11 @@ public List<Float> embed(String text) {
6464
}
6565

6666
public static final class Builder implements LanguageModelFactory {
67-
private Configuration config;
67+
private PaLMConfig config;
6868
private PaLMClient client;
6969

7070
public LanguageModel getLanguageModel(Configuration config) {
71-
this.config = config;
71+
this.config = new PaLMConfig(config);
7272
this.client = new PaLMClient.Builder().withConfig(config).build();
7373
return new PaLMLanguageModel(this);
7474
}

llm4j-palm/src/main/java/org/llm4j/palm/request/EmbedRequestFactory.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import org.llm4j.palm.PaLMConfig;
66

77
public class EmbedRequestFactory {
8-
static String EMBED_MODEL_ID = "models/embedding-gecko-001";
98

109
/**
1110
* which model to use to generate the result
@@ -19,8 +18,8 @@ public EmbedRequestFactory withText(String text) {
1918
return this;
2019
}
2120

22-
public EmbedRequestFactory withConfig(Configuration configs) {
23-
this.modelId = configs.getString(PaLMConfig.MODEL_ID_KEY, EMBED_MODEL_ID);
21+
public EmbedRequestFactory withConfig(PaLMConfig configs) {
22+
this.modelId = configs.getEmbedModelId();
2423
return this;
2524
}
2625

llm4j-palm/src/main/java/org/llm4j/palm/request/MessageRequestFactory.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
public class MessageRequestFactory {
1111

12-
static String CHAT_MODEL_ID = "models/chat-bison-001";
13-
1412
/**
1513
* which model to use to generate the result
1614
*/
@@ -56,8 +54,8 @@ public MessageRequestFactory withContext(String context) {
5654
return this;
5755
}
5856

59-
public MessageRequestFactory withConfig(Configuration configs) {
60-
this.modelId = configs.getString(PaLMConfig.MODEL_ID_KEY, CHAT_MODEL_ID);
57+
public MessageRequestFactory withConfig(PaLMConfig configs) {
58+
this.modelId = configs.getChatModelId();
6159
this.parameters = PaLMRequestParameters.builder()
6260
.withConfig(configs)
6361
.build();

llm4j-palm/src/main/java/org/llm4j/palm/request/PaLMRequestParameters.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.llm4j.palm.request;
22

33
import org.apache.commons.configuration2.Configuration;
4+
import org.llm4j.palm.PaLMConfig;
45

56
public class PaLMRequestParameters {
67

@@ -42,13 +43,13 @@ public static final class Builder {
4243
private Integer maxOutputTokens;
4344
private Integer candidateCount;
4445

45-
public Builder withConfig(Configuration configs) {
46-
this.topK = configs.getInteger("topK", 1);
47-
this.topP = configs.getDouble("topP", 0.9);
48-
this.temperature = configs.getDouble("temperature", 1.0);
49-
this.maxNewTokens = configs.getInteger("maxNewTokens", 10);
50-
this.maxOutputTokens = configs.getInteger("maxOutputTokens", 1000);
51-
this.candidateCount = configs.getInteger("candidateCount", 1);
46+
public Builder withConfig(PaLMConfig configs) {
47+
this.topK = configs.getTopK();
48+
this.topP = configs.getTopP();
49+
this.temperature = configs.getTemperature();
50+
this.maxNewTokens = configs.getMaxNewTokens();
51+
this.maxOutputTokens = configs.getMaxOutputTokens();
52+
this.candidateCount = configs.getCandidateCount();
5253
return this;
5354
}
5455
public PaLMRequestParameters build() {

llm4j-palm/src/main/java/org/llm4j/palm/request/TextRequestFactory.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
import com.google.ai.generativelanguage.v1beta2.GenerateTextRequest;
44
import com.google.ai.generativelanguage.v1beta2.TextPrompt;
5-
import org.apache.commons.configuration2.Configuration;
65
import org.llm4j.palm.PaLMConfig;
76

87
public class TextRequestFactory {
9-
static String TEXT_MODEL_ID = "models/text-bison-001";
108

119
/**
1210
* which model to use to generate the result
@@ -22,8 +20,8 @@ public TextRequestFactory withInputs(String inputs) {
2220
return this;
2321
}
2422

25-
public TextRequestFactory withConfig(Configuration configs) {
26-
this.modelId = configs.getString(PaLMConfig.MODEL_ID_KEY, TEXT_MODEL_ID);
23+
public TextRequestFactory withConfig(PaLMConfig configs) {
24+
this.modelId = configs.getTextModelId();
2725
this.parameters = PaLMRequestParameters.builder()
2826
.withConfig(configs)
2927
.build();

0 commit comments

Comments
 (0)