
Commit a61dbee

add ability to return logprobs for lmtp models
1 parent eeb2c0f commit a61dbee

File tree: 3 files changed (+7 additions, -9 deletions)


src/lmql/models/lmtp/lmtp_dcmodel.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def make_logits(self, payload):
     async def singleton_result(self, token, score):
         yield {"token": token, "logprob": score, "top_logprobs": {token: score}}
 
-    async def generate(self, s, temperature, top_logprobs = 1, chunk_size=None, **kwargs):
+    async def generate(self, s, temperature, top_logprobs = 5, chunk_size=None, **kwargs):
         kwargs = {**self.model_args, **kwargs}
 
         # get token masks from interpreter
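
With the default for top_logprobs raised from 1 to 5, each chunk streamed by generate() can now carry the five most likely tokens at that step instead of only the sampled one. Below is a minimal, hypothetical sketch of consuming one such chunk, assuming the payload shape shown in singleton_result above; the tokens and scores are invented for illustration.

import math

# Hypothetical chunk in the shape yielded by singleton_result above,
# here with five alternatives as top_logprobs=5 would produce.
chunk = {
    "token": "1",
    "logprob": -0.21,
    "top_logprobs": {"0": -2.3, "1": -0.21, "2": -3.1, "one": -4.2, "two": -4.8},
}

# Convert the per-token logprobs to probabilities for inspection.
probs = {tok: math.exp(lp) for tok, lp in chunk["top_logprobs"].items()}
print(max(probs, key=probs.get))  # -> "1"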

src/lmql/runtime/dclib/dclib_seq.py

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ def make_successors(self, next_tokens, next_token_scores, logits, user_data=None):
         tokens = [t for t, s in zip(next_tokens, next_token_scores) if s > get_truncation_threshold()]
         scores = [s for s in next_token_scores if s > get_truncation_threshold()]
 
-        distribution_logprobs = [{k: v for k, v in logits.probs.items() if type(k) == str}]
+        distribution_logprobs = [{get_tokenizer().decode([k]): v for k, v in logits.probs.items() if type(k) == int and v > -10}]
         if len(distribution_logprobs[0]) < 1:
             distribution_logprobs = None
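
The replaced line changes what survives into distribution_logprobs: the old comprehension kept only keys that were already strings, which comes up empty when an LMTP backend reports integer token ids, while the new one decodes integer ids to text with the tokenizer and drops entries scored at or below -10. A self-contained sketch of that mapping step, using a stand-in decode() in place of get_tokenizer().decode and an invented vocabulary:

# Stand-in for get_tokenizer().decode: turns a list of token ids into text.
def decode(ids):
    vocab = {15: "0", 16: "1", 17: "2"}
    return "".join(vocab[i] for i in ids)

# Mixed-key logprobs as an LMTP backend might report them.
probs = {15: -0.4, 16: -1.9, 17: -12.0, "eos": -0.1}

# Keep only integer token ids, decode them, and drop very unlikely
# entries (logprob <= -10), mirroring the added line above.
distribution_logprobs = [{decode([k]): v for k, v in probs.items()
                          if type(k) == int and v > -10}]
print(distribution_logprobs)  # [{'0': -0.4, '1': -1.9}]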

test.py

Lines changed: 5 additions & 7 deletions
@@ -5,21 +5,19 @@
 
 #add replicate api key to env
 import os
-os.environ['REPLICATE_API_TOKEN'] = 'r8_[REDACTED]'
 
 # def test_decorator(variable_value, prompt_value, context):
 #     return variable_value, prompt_value
 
 async def main():
 
     test = lmql.model(
-        "openai/gpt-3.5-turbo-instruct"
-        # "meta-llama/Llama-2-13b-chat-hf",
-        # endpoint="replicate:deployment/ml-delphai/llama2-13b-chat-lmtp",
+        # "openai/gpt-3.5-turbo-instruct"
+        "meta-llama/Llama-2-13b-chat-hf",
+        endpoint="replicate:deployment/ml-delphai/llama2-13b-chat-lmtp",
         # endpoint="replicate:charles-dyfis-net/llama-2-7b-chat-hf--lmtp-8bit",
-        # tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded",
+        tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded",
     )
-    pass
 
     answer = await lmql.run(
         """
@@ -32,7 +30,7 @@ def get_probs(variable_value, prompt_value, context):
         if value > -5:
             scores[key] = math.exp(value)
     return scores
-argmax(verbose=True)
+argmax
 \"How much you like monkeys between 0 and 2?[@get_probs MONKEY]\" where MONKEY in set ([\"0\", \"1\", \"2\"])
 \"How much you like birds between 0 and 2?[@get_probs BIRD]\" where BIRD in set ([\"0\", \"1\", \"2\"])
 return (MONKEY, BIRD)
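
For context, the @get_probs decorator in the query turns the logprobs attached to a variable into plain probabilities via math.exp, keeping only tokens with logprob above -5. A standalone sketch of that conversion with a stubbed context object; the top_logprobs attribute and its {decoded_token: logprob} shape are assumptions based on the dclib_seq.py change above, and real LMQL decorators receive richer arguments.

import math
from types import SimpleNamespace

def get_probs(variable_value, prompt_value, context):
    # Mirrors the decorator body in the diff: keep tokens with
    # logprob > -5 and exponentiate them into probabilities.
    scores = {}
    for key, value in context.top_logprobs.items():
        if value > -5:
            scores[key] = math.exp(value)
    return scores

# Stubbed context carrying an assumed {decoded_token: logprob} mapping.
ctx = SimpleNamespace(top_logprobs={"0": -2.3, "1": -0.2, "2": -6.0})
print(get_probs("1", "How much you like monkeys...", ctx))
# -> {'0': 0.1002..., '1': 0.8187...}  ("2" falls below the cutoff)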
