88import org .junit .Test ;
99import org .junit .Assert ;
1010
11-
1211public class LlamaChatModelTest {
13-
12+
1413 private static LlamaModel model ;
15-
14+
1615 @ BeforeClass
1716 public static void setup () {
18- // LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg));
1917 model = new LlamaModel (
20- new ModelParameters ()
18+ new ModelParameters ()
2119 .setCtxSize (128 )
2220 .setModel ("models/codellama-7b.Q2_K.gguf" )
23- //.setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf")
2421 .setGpuLayers (43 )
25- .enableEmbedding ().enableLogTimestamps ().enableLogPrefix ()
26- );
22+ .enableLogTimestamps ()
23+ .enableLogPrefix ()
24+ );
2725 }
2826
2927 @ AfterClass
@@ -32,42 +30,91 @@ public static void tearDown() {
3230 model .close ();
3331 }
3432 }
33+
34+ @ Test
35+ public void testMultiTurnChat () {
36+ List <Pair <String , String >> userMessages = new ArrayList <>();
37+ userMessages .add (new Pair <>("user" , "Recommend a good ML book." ));
38+
39+ InferenceParameters params = new InferenceParameters ("" )
40+ .setMessages ("You are a Book Recommendation System" , userMessages )
41+ .setTemperature (0.7f )
42+ .setNPredict (50 );
43+
44+ String response1 = model .completeChat (params );
45+ Assert .assertNotNull (response1 );
46+
47+ userMessages .add (new Pair <>("assistant" , response1 ));
48+ userMessages .add (new Pair <>("user" , "How does it compare to 'Hands-on ML'?" ));
49+
50+ params .setMessages ("Book" , userMessages );
51+ String response2 = model .completeChat (params );
52+
53+ Assert .assertNotNull (response2 );
54+ Assert .assertNotEquals (response1 , response2 );
55+ }
3556
3657 @ Test
37- public void testChat () {
38- List <Pair <String , String >> userMessages = new ArrayList <>();
39- userMessages .add (new Pair <>("user" , "What is the best book for machine learning?" ));
40-
41- InferenceParameters params = new InferenceParameters ("A book recommendation system." )
42- .setMessages ("Book" , userMessages )
43- .setTemperature (0.0f )
44- .setStopStrings ("\" \" \" " )
45- .setNPredict (10 )
46- .setSeed (42 );
47-
48- String assistantResponse = model .completeChat (params );
49-
50- Assert .assertNotNull (assistantResponse );
51-
52- Assert .assertEquals (params .get ("prompt" ), "\" <|im_start|>system\\ nBook<|im_end|>\\ n<|im_start|>user\\ nWhat is the best book for machine learning?<|im_end|>\\ n<|im_start|>assistant\\ n\" " );
53-
54- userMessages .add (new Pair <>("assistant" , assistantResponse ));
55- userMessages .add (new Pair <>("user" , "that is great book for machine learning?, what about linear algebra" ));
56-
57- params = new InferenceParameters ("A book recommendation system." )
58- .setMessages ("Book" , userMessages )
59- .setTemperature (0.0f )
60- .setStopStrings ("\" \" \" " )
61- .setNPredict (10 )
62- .setSeed (42 );
63-
64-
65- assistantResponse = model .completeChat (params );
66- Assert .assertNotNull (assistantResponse );
67-
68- Assert .assertEquals (params .get ("prompt" ), "\" <|im_start|>system\\ nBook<|im_end|>\\ n<|im_start|>user\\ nWhat is the best book for machine learning?<|im_end|>\\ n<|im_start|>assistant\\ nWhat is the best book for machine learning?<<|im_end|>\\ n<|im_start|>user\\ nthat is great book for machine learning?, what about linear algebra<|im_end|>\\ n<|im_start|>assistant\\ n\" " );
69-
70-
58+ public void testEmptyInput () {
59+ List <Pair <String , String >> userMessages = new ArrayList <>();
60+ userMessages .add (new Pair <>("user" , "" ));
61+
62+ InferenceParameters params = new InferenceParameters ("A book recommendation system." )
63+ .setMessages ("Book" , userMessages )
64+ .setTemperature (0.5f )
65+ .setNPredict (20 );
66+
67+ String response = model .completeChat (params );
68+ Assert .assertNotNull (response );
69+ Assert .assertFalse (response .isEmpty ());
70+ }
71+
72+ @ Test
73+ public void testStopString () {
74+ List <Pair <String , String >> userMessages = new ArrayList <>();
75+ userMessages .add (new Pair <>("user" , "Tell me about AI ethics." ));
76+
77+ InferenceParameters params = new InferenceParameters ("A book recommendation system." )
78+ .setMessages ("AI" , userMessages )
79+ .setStopStrings ("\" \" \" " ) // Ensures stopping at proper place
80+ .setTemperature (0.7f )
81+ .setNPredict (50 );
82+
83+ String response = model .completeChat (params );
84+ Assert .assertNotNull (response );
85+ Assert .assertFalse (response .contains ("\" \" \" " ));
86+ }
87+
88+ @ Test
89+ public void testFixedSeed () {
90+ List <Pair <String , String >> userMessages = new ArrayList <>();
91+ userMessages .add (new Pair <>("user" , "What is reinforcement learning?" ));
92+
93+ InferenceParameters params = new InferenceParameters ("AI Chatbot." )
94+ .setMessages ("AI" , userMessages )
95+ .setTemperature (0.7f )
96+ .setSeed (42 ) // Fixed seed for reproducibility
97+ .setNPredict (50 );
98+
99+ String response1 = model .completeChat (params );
100+ String response2 = model .completeChat (params );
101+
102+ Assert .assertEquals (response1 , response2 ); // Responses should be identical
103+ }
104+
105+ @ Test
106+ public void testNonEnglishInput () {
107+ List <Pair <String , String >> userMessages = new ArrayList <>();
108+ userMessages .add (new Pair <>("user" , "Quel est le meilleur livre sur l'apprentissage automatique ?" ));
109+
110+ InferenceParameters params = new InferenceParameters ("A book recommendation system." )
111+ .setMessages ("Book" , userMessages )
112+ .setTemperature (0.7f )
113+ .setNPredict (50 );
114+
115+ String response = model .completeChat (params );
116+ Assert .assertNotNull (response );
117+ Assert .assertTrue (response .length () > 5 ); // Ensure some response is generated
71118 }
72119
73120}
0 commit comments