Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions examples/qualcomm/oss_scripts/llama/decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__( # noqa: C901
get_example_inputs: Callable,
use_i64_token: bool,
seq_mse_candidates: int,
skip_generate: bool = False,
):
# n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
assert max_seq_length is not None, "max_seq_length must be provided"
Expand All @@ -107,13 +108,15 @@ def __init__( # noqa: C901
self.max_seq_length = max_seq_length
self.use_i64_token = use_i64_token
self.seq_mse_candidates = seq_mse_candidates
self.skip_generate = skip_generate

def _model_call(self, inps):
all_logits = None
kwargs = {}
if self._use_kv_cache:
kwargs["ar_len"] = self.ar_len
kwargs["seq_mse_candidates"] = self.seq_mse_candidates
kwargs["skip_generate"] = self.skip_generate

all_logits = INFERENCE_REGISTRY[self._use_kv_cache](
self.get_example_inputs,
Expand Down Expand Up @@ -675,6 +678,7 @@ def kv_inference( # noqa: C901
collect_logits=False,
seq_mse_candidates=0,
lookahead_config=None,
skip_generate=False,
):
is_multimodal = all(
[
Expand Down Expand Up @@ -781,19 +785,22 @@ def kv_inference( # noqa: C901

# Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
# When run on wikitext for ppl evaluation, this while-loop is not expected to run.
_generate(
inputs,
cur_pos,
module,
tokenizer,
tok_embedding,
ar_len,
max_seq_len,
k_caches,
v_caches,
total_token_list,
lookahead_config,
)
# During calibration, skip_generate=True skips this loop since quantization
# observers already have sufficient activation statistics from the prefill pass.
if not skip_generate:
_generate(
inputs,
cur_pos,
module,
tokenizer,
tok_embedding,
ar_len,
max_seq_len,
k_caches,
v_caches,
total_token_list,
lookahead_config,
)

logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
if collect_logits:
Expand Down Expand Up @@ -900,6 +907,7 @@ def graph_module_inference(
event_name: Optional[str] = None,
seq_mse_candidates: int = 0,
lookahead_config: Optional[Tuple[int]] = None,
skip_generate: bool = False,
):
"""
This function supports model execution from static nn.Module decoder model
Expand All @@ -915,6 +923,7 @@ def graph_module_inference(
if use_kv_cache:
kwargs["ar_len"] = ar_len
kwargs["lookahead_config"] = lookahead_config
kwargs["skip_generate"] = skip_generate

INFERENCE_REGISTRY[use_kv_cache](
get_example_inputs,
Expand All @@ -940,6 +949,7 @@ def graph_module_inference(
get_example_inputs=get_example_inputs,
use_i64_token=use_i64_token,
seq_mse_candidates=seq_mse_candidates,
skip_generate=skip_generate,
)
# Evaluate the model
with torch.no_grad():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def _calibrate(
use_i64_token=self.control_args.embedding_quantize is not None,
event_name=f"{event}_tasks",
seq_mse_candidates=self.config.seq_mse_candidates,
skip_generate=True,
)

# prepare lookahead config if applicable
Expand Down Expand Up @@ -533,6 +534,7 @@ def _calibrate(
use_i64_token=self.control_args.embedding_quantize is not None,
event_name=f"{event}_prompt",
lookahead_config=lookahead_config,
skip_generate=True,
)

@log_info
Expand Down
Loading