@@ -1,26 +1,42 @@
 package de.kherud.llama;
 
-import java.lang.annotation.Native;
+
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
 /**
- * This iterator is used by {@link LlamaModel#generate(InferenceParameters)}. In addition to implementing {@link Iterator},
- * it allows to cancel ongoing inference (see {@link #cancel()}).
+ * Iterates over a stream of outputs from the model.
  */
-public final class LlamaIterator implements Iterator<LlamaOutput> {
+public class LlamaIterator implements Iterator<LlamaOutput> {
 
     private final LlamaModel model;
+    private final boolean isChat;
     private final int taskId;
 
-    @Native
-    @SuppressWarnings("FieldMayBeFinal")
-    private boolean hasNext = true;
+    /**
+     * Whether there is a next token to receive.
+     */
+    public boolean hasNext = true;
 
-    LlamaIterator(LlamaModel model, InferenceParameters parameters) {
+    /**
+     * Creates a new iterator.
+     *
+     * @param model      the llama model to use for generation
+     * @param parameters parameters for the inference
+     * @param isChat     whether this is a chat completion (true) or a regular
+     *                   completion (false)
+     */
+    LlamaIterator(LlamaModel model, InferenceParameters parameters, boolean isChat) {
         this.model = model;
-        parameters.setStream(true);
-        taskId = model.requestCompletion(parameters.toString());
+        this.isChat = isChat;
+
+        if (isChat) {
+            String prompt = model.applyTemplate(parameters);
+            parameters.setPrompt(prompt);
+            this.taskId = model.requestChat(parameters.toString());
+        } else {
+            this.taskId = model.requestCompletion(parameters.toString());
+        }
     }
 
     @Override
@@ -33,19 +49,38 @@ public LlamaOutput next() {
         if (!hasNext) {
             throw new NoSuchElementException();
         }
-        LlamaOutput output = model.receiveCompletion(taskId);
-        hasNext = !output.stop;
-        if (output.stop) {
-            model.releaseTask(taskId);
+
+        try {
+            if (isChat) {
+                String response = model.streamChatCompletion(taskId);
+                // Check for completion by examining the JSON response.
+                // This is a simplification; the actual implementation might need
+                // more sophisticated handling.
+                if (response != null && response.contains("\"finish_reason\":")) {
+                    hasNext = false;
+                }
+                return new LlamaOutput(response, !hasNext);
+            } else {
+                StreamingOutput output = model.streamCompletion(taskId);
+                hasNext = !output.isFinal;
+                return new LlamaOutput(output.text, output.isFinal);
+            }
+        } catch (Exception e) {
+            model.releaseTask(taskId);
+            hasNext = false;
+            throw new RuntimeException(e);
         }
-        return output;
     }
 
     /**
-     * Cancel the ongoing generation process.
+     * Cancel the ongoing generation process. This will stop the model from
+     * generating more tokens and release resources.
      */
     public void cancel() {
-        model.cancelCompletion(taskId);
-        hasNext = false;
+        if (hasNext) {
+            model.cancelCompletion(taskId);
+            model.releaseTask(taskId);
+            hasNext = false;
+        }
     }
-}
+}
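
Usage note: below is a minimal sketch of driving this iterator for a regular (non-chat) completion. Only the iterator calls (hasNext, next, cancel) mirror the diff above; how the model is obtained and the InferenceParameters constructor are assumptions, marked as such in the comments. The snippet sits in de.kherud.llama because the constructor is package-private.

    package de.kherud.llama;

    public class LlamaIteratorSketch {
        public static void main(String[] args) {
            LlamaModel model = null; // assumption: replace with a loaded model
            InferenceParameters params = new InferenceParameters(""); // assumed constructor
            params.setPrompt("Once upon a time"); // setPrompt is used by the constructor above

            LlamaIterator it = new LlamaIterator(model, params, false); // regular completion

            int tokenBudget = 64; // stop early after a fixed number of outputs
            while (it.hasNext()) {
                LlamaOutput output = it.next();
                System.out.print(output); // assumes LlamaOutput prints its text
                if (--tokenBudget == 0) {
                    it.cancel(); // stops generation and releases the native task
                    break;
                }
            }
        }
    }

For a chat completion, the same loop applies with isChat = true; per the constructor above, the model's chat template is then applied to the prompt and the task is requested via requestChat.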
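The comment inside next() already flags the finish_reason substring check as a simplification: it also fires when the generated text itself happens to contain the characters "finish_reason":. A stricter alternative is to parse each chunk before ending the stream. The sketch below is an assumption-laden illustration: it presumes Gson on the classpath and an OpenAI-style chunk shape ({"choices":[{"finish_reason":...}]}), neither of which this commit establishes, and the helper name isFinalChunk is hypothetical.

    import com.google.gson.JsonElement;
    import com.google.gson.JsonObject;
    import com.google.gson.JsonParser;

    final class ChunkInspector { // hypothetical helper, not part of this commit
        static boolean isFinalChunk(String response) {
            if (response == null) {
                return false;
            }
            JsonObject chunk = JsonParser.parseString(response).getAsJsonObject();
            if (!chunk.has("choices")) {
                return false; // unexpected shape: treat as a non-final chunk
            }
            for (JsonElement choice : chunk.getAsJsonArray("choices")) {
                JsonElement reason = choice.getAsJsonObject().get("finish_reason");
                // finish_reason stays JSON null until the stream ends
                if (reason != null && !reason.isJsonNull()) {
                    return true;
                }
            }
            return false;
        }
    }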