Commit 4e45a96

Merge pull request #1 from fjfricke/inline_probabilities
Adds a first version that lets you get logprobs for tokens in argmax decoding when using a `where ... in set(...)` constraint.
2 parents dfb179f + 2d32d48

File tree: 5 files changed (+79 -9 lines)


src/lmql/runtime/dclib/dclib_array.py

4 additions, 2 deletions

@@ -12,6 +12,7 @@ class Continuation:
     token: Any
     logprob: Any
     user_data: Any
+    distribution_logprobs: Any = None

 class criterion:
     def __and__(self, other):
         return logical_and(self, other)

@@ -325,8 +326,9 @@ def op_extend(p1, p2):
             tokens = continuation.token.reshape(-1)
             logprobs = continuation.logprob.reshape(-1)
             user_data = continuation.user_data or [None] * len(tokens)
-            for t,s,u in zip(tokens, logprobs, user_data):
-                extended_seqs.append(sq.extend(Continuation(t, s, u)))
+            distribution_logprobs = continuation.distribution_logprobs or [None] * len(tokens)
+            for t,s,u,d in zip(tokens, logprobs, user_data, distribution_logprobs):
+                extended_seqs.append(sq.extend(Continuation(t, s, u, d)))
             return extended_seqs

         return DataArray(apply_componentwise(op_extend, self.sequences, other.sequences, "extend", allow_mismatch_keys=False), dims=self.shape)
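
The new field is optional with a None default, so existing Continuation(token, logprob, user_data) call sites keep working. A minimal standalone sketch of the fan-out in op_extend, using simplified stand-ins rather than the real LMQL classes:

from dataclasses import dataclass
from typing import Any

@dataclass
class Continuation:
    token: Any
    logprob: Any
    user_data: Any
    distribution_logprobs: Any = None  # new field; None keeps old call sites valid

# Fan-out as in op_extend: when no distribution was recorded, substitute
# [None] * len(tokens) so the zip() stays aligned with tokens and logprobs.
tokens, logprobs = ["0", "1"], [-0.11, -2.4]
user_data = [None] * len(tokens)
c = Continuation(tokens, logprobs, user_data)
dists = c.distribution_logprobs or [None] * len(tokens)
for t, s, u, d in zip(tokens, logprobs, user_data, dists):
    print(t, s, u, d)  # d stays None unless a backend supplied a token->logprob dict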

src/lmql/runtime/dclib/dclib_seq.py

12 additions, 3 deletions

@@ -82,7 +82,7 @@ async def json(self, diff: bool = False):
         }

 class DecoderSequence:
-    def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False):
+    def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False, distribution_logprobs=None):
         if logprobs is not None:
             if not all([p > get_truncation_threshold() for p in logprobs]):
                 warnings.warn("logprobs contain values below the current logprob truncation threshold {t}, which may cause unexpected behavior. Consider increasing the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))

@@ -141,6 +141,10 @@ def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phr

         # indicates to dc.rewrite whether this sequence can be rewritten
         self.needs_rewrite = True
+        if not distribution_logprobs:
+            self.distribution_logprobs = [None] * len(self.logprobs)
+        else:
+            self.distribution_logprobs = distribution_logprobs

     def __hash__(self) -> int:
         return hash(self.id)

@@ -371,7 +375,8 @@ def extend(self, continuation, internal=False):
             predecessor=self,
             user_data=self.extend_user_data(continuation),
             sticky_user_data_keys=self.sticky_user_data_keys,
-            internal=internal
+            internal=internal,
+            distribution_logprobs=self.distribution_logprobs + [continuation.distribution_logprobs]
         )

     def detect_stop_phrase(self, continuation):

@@ -436,14 +441,18 @@ def make_successors(self, next_tokens, next_token_scores, logits, user_data=None
         tokens = [t for t, s in zip(next_tokens, next_token_scores) if s > get_truncation_threshold()]
         scores = [s for s in next_token_scores if s > get_truncation_threshold()]

+        distribution_logprobs = [{k: v for k, v in logits.probs.items() if type(k) == str}]
+        if len(distribution_logprobs[0]) < 1:
+            distribution_logprobs = None
+
         if len(tokens) == 0:
             print("WARNING: all continuation token fall below the current logprob truncation threshold {t}. This is likely due to a too low truncation threshold. Please increase the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))
             tokens = [t for t, s in zip(next_tokens, next_token_scores)][:1]
             scores = [s for s in next_token_scores][:1]
         next_tokens = np.stack(tokens, axis=0)
         next_token_scores = np.stack(scores, axis=0)

-        return Continuation(next_tokens, next_token_scores, user_data)
+        return Continuation(next_tokens, next_token_scores, user_data, distribution_logprobs)
 # global counter for all sequences created in this process for identification purposes
 DecoderSequence.seq_ctr = 0
 DecoderSequence.graph = None
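
The net effect is that every DecoderSequence carries a distribution_logprobs list that grows in lockstep with logprobs: one entry per decoded token, filled by make_successors with the string-keyed part of logits.probs (the detokenized candidates) or None when no such keys are present. A hedged sketch of that bookkeeping, using a hypothetical stand-in class and invented values:

# Hypothetical stand-in for DecoderSequence's new bookkeeping (simplified):
class Seq:
    def __init__(self, logprobs, distribution_logprobs=None):
        self.logprobs = logprobs
        # same default as the commit: one None slot per existing token
        self.distribution_logprobs = distribution_logprobs or [None] * len(logprobs)

    def extend(self, token_logprob, token_distribution):
        # mirrors DecoderSequence.extend: append the continuation's distribution
        return Seq(self.logprobs + [token_logprob],
                   self.distribution_logprobs + [token_distribution])

s = Seq(logprobs=[-0.5])                                  # e.g. one prompt token
s = s.extend(-0.11, {"0": -0.11, "1": -2.4, "2": -4.0})   # string-keyed, as filtered from logits.probs
assert len(s.logprobs) == len(s.distribution_logprobs)
print(s.distribution_logprobs[1])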

src/lmql/runtime/interpreter.py

14 additions, 3 deletions

@@ -102,6 +102,8 @@ class PromptState(NamedTuple):
     where: Optional[Any]
     tail: Optional[str]

+    distribution_logprobs: Optional[Dict[str, float]] = {}
+
     def __str__(self):
         return f"<PromptState '{self.variable}' '{[self.prompt]}'>"

@@ -546,15 +548,20 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
         # update hint for max_tokens to generate for current var
         max_tokens_hint = ops.most_restrictive_hint([sub_max_token_hints, max_tokens_hint])

+        if len(s.distribution_logprobs) > variable_offset:
+            scores = s.distribution_logprobs[variable_offset]
+        else:
+            scores = None
+
         # current context
         program_state: ProgramState = state.program_state.copy()
-        program_state.set(variable, text, scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
+        program_state.set(variable, text, scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
         program_state.subinterpreter_results = subvalid
         program_state.prompt = state.prompt

         # follow context
         follow_program_state: ProgramState = state.program_state.copy()
-        follow_program_state.set(variable, text + str(ops.NextToken), scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
+        follow_program_state.set(variable, text + str(ops.NextToken), scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
         follow_program_state.subinterpreter_results = subfollow
         follow_program_state.prompt = state.prompt

@@ -614,6 +621,7 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
             program_state=program_state,
             stopping_phrases=stopping_phrases,
             where=await self.where_graph_with_trace(where, trace, follow_trace),
+            distribution_logprobs=scores,
         )

         # extract hint of maximum number of tokens to generate for 'variable' from

@@ -761,7 +769,7 @@ async def rewrite_for_sequence(self, seq: dc.DecoderSequence, needs_rewrite, ass

         variable_value = text
         # set raw variable value
-        program_state.set(variable, variable_value, scores=(), diff=text_diff, montonicity="fin", tokens=text_tokens)
+        program_state.set(variable, variable_value, scores=state.distribution_logprobs, diff=text_diff, montonicity="fin", tokens=text_tokens)

         where = state.full_where_condition(self)

@@ -1022,6 +1030,9 @@ async def debug_out(decoder_step):
         if _DCLibDebugPrinter.printer.records_graph:
             dc.set_record_graph()
             self.decoder_graph = dc.DecoderSequence.graph
+        if self.model.adapter.decoder_args.get("decoder_graph", False):
+            dc.set_record_graph()
+            self.decoder_graph = dc.DecoderSequence.graph

         # get decoder function
         mode = decoder_args["decoder"].lower()
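
In where_step_for_sequence, the distribution recorded at variable_offset (the position where the current variable starts) now replaces the empty scores tuple passed to program_state.set, and it is also stored on PromptState so rewrite_for_sequence can reuse it once the variable is finalized. A small sketch of the lookup, with hypothetical data:

# Hypothetical data: one recorded (or missing) distribution per decoded position.
distribution_logprobs = [
    None,                                # prompt token: nothing recorded
    {"0": -0.11, "1": -2.4, "2": -4.0},  # first token of the current variable
]
variable_offset = 1  # position at which the current variable starts

# same guard as the commit: fall back to None when nothing was recorded yet
if len(distribution_logprobs) > variable_offset:
    scores = distribution_logprobs[variable_offset]
else:
    scores = None
print(scores)  # forwarded via program_state.set(..., scores=scores) and PromptState.distribution_logprobs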

src/lmql/runtime/openai_integration.py

1 addition, 1 deletion

@@ -557,7 +557,7 @@ async def op_sample(seqs):
         sampling_modes = [f"sample-{temperature}-sample-id-{random.randint(0, 2**32-1)}" for _ in range(len(seqs))]
         edge_type_populated_user_data = [{"dc-edge-type": sm} for sm in sampling_modes]

-        completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=num_samples, sampling_modes=sampling_modes, **kwargs)
+        completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=5, sampling_modes=sampling_modes, **kwargs)

         next_token_ids = []
         next_token_scores = []
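
Hardcoding logprobs=5 requests the maximum number of alternatives the OpenAI Completions API accepts for its logprobs parameter, instead of tying it to num_samples; the resulting per-position token-to-logprob dict is what surfaces as the string-keyed distribution in make_successors above. A sketch of the payload shape, with invented values:

import math

# Hypothetical top_logprobs entry for one position, shaped like what the
# Completions API returns when logprobs=5 is requested (values invented).
top_logprobs = {"0": -0.11, "1": -2.4, "2": -4.0, " 0": -5.2, "\n": -6.1}

# a downstream consumer can turn these into probabilities:
probs = {tok: math.exp(lp) for tok, lp in top_logprobs.items()}
print(probs)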

test.py

48 additions, 0 deletions (new file)

@@ -0,0 +1,48 @@
+import sys
+sys.path.append('/Users/felix/Programming/lmql/src')
+import lmql
+import asyncio
+
+#add replicate api key to env
+import os
+os.environ['REPLICATE_API_TOKEN'] = 'r8_aOlrg82Wfg30Rx4L4mv9wI2npPfBQGO0Pvci4'
+
+# def test_decorator(variable_value, prompt_value, context):
+#     return variable_value, prompt_value
+
+async def main():
+
+    test = lmql.model(
+        "openai/gpt-3.5-turbo-instruct"
+        # "meta-llama/Llama-2-13b-chat-hf",
+        # endpoint="replicate:deployment/ml-delphai/llama2-13b-chat-lmtp",
+        # endpoint="replicate:charles-dyfis-net/llama-2-7b-chat-hf--lmtp-8bit",
+        # tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded",
+    )
+    pass
+
+    answer = await lmql.run(
+        """
+import math
+def get_probs(variable_value, prompt_value, context):
+    breakpoint()
+    logprob_scores = list(context.variable_scores.items())[-1][1]
+    scores = dict()
+    for key, value in logprob_scores.items():
+        if value > -5:
+            scores[key] = math.exp(value)
+    return scores
+argmax(verbose=True)
+    \"How much you like monkeys between 0 and 2?[@get_probs MONKEY]\" where MONKEY in set ([\"0\", \"1\", \"2\"])
+    \"How much you like birds between 0 and 2?[@get_probs BIRD]\" where BIRD in set ([\"0\", \"1\", \"2\"])
+    return (MONKEY, BIRD)
+""",
+        max_len=4000,
+        model=test,
+        # decoder_graph=True,
+    )
+
+    print(answer)
+
+if __name__ == "__main__":
+    asyncio.run(main())
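
The test drives the feature end to end: each [@get_probs VAR] decorator receives the interpreter context, reads the logprobs recorded for the most recent variable from context.variable_scores, and converts those above a cutoff into probabilities. The same logic can be checked offline, assuming a hypothetical FakeContext in place of the real interpreter context:

import math

# Hypothetical stand-in for the interpreter context the decorator receives.
class FakeContext:
    variable_scores = {"MONKEY": {"0": -0.11, "1": -2.4, "2": -6.0}}

def get_probs(variable_value, prompt_value, context):
    # same logic as in test.py: take the scores of the most recent variable
    logprob_scores = list(context.variable_scores.items())[-1][1]
    # keep candidates above the -5 logprob cutoff, converted to probabilities
    return {k: math.exp(v) for k, v in logprob_scores.items() if v > -5}

print(get_probs("0", None, FakeContext()))  # {'0': 0.895..., '1': 0.090...}; "2" is filtered out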
