Skip to content

Commit 992c82b

Browse files
committed
Merge branch 'greedy_sample' of https://github.com/ModelTC/lightllm into fa3_mtp
2 parents 3f9fed0 + ae83cd4 commit 992c82b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
6464

6565
if is_all_greedy:
6666
batch_next_token_ids = torch.argmax(logits, -1)
67-
batch_next_token_probs = torch.nn.functional.log_softmax(logits, dim=-1)
67+
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
68+
batch_next_token_probs = torch.gather(log_probs, dim=1, index=batch_next_token_ids.view(-1, 1))
6869
return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
6970

7071
elif get_env_start_args().sampling_backend == "triton":

0 commit comments

Comments
 (0)