diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
index 20a7ab99c8d..8e9b2145951 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
@@ -94,6 +94,7 @@ def __init__(  # noqa: C901
         get_example_inputs: Callable,
         use_i64_token: bool,
         seq_mse_candidates: int,
+        skip_generate: bool = False,
     ):
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
         assert max_seq_length is not None, "max_seq_length must be provided"
@@ -107,6 +108,7 @@ def __init__(  # noqa: C901
         self.max_seq_length = max_seq_length
         self.use_i64_token = use_i64_token
         self.seq_mse_candidates = seq_mse_candidates
+        self.skip_generate = skip_generate
 
     def _model_call(self, inps):
         all_logits = None
@@ -114,6 +116,7 @@ def _model_call(self, inps):
         if self._use_kv_cache:
             kwargs["ar_len"] = self.ar_len
             kwargs["seq_mse_candidates"] = self.seq_mse_candidates
+            kwargs["skip_generate"] = self.skip_generate
 
         all_logits = INFERENCE_REGISTRY[self._use_kv_cache](
             self.get_example_inputs,
@@ -675,6 +678,7 @@ def kv_inference(  # noqa: C901
     collect_logits=False,
     seq_mse_candidates=0,
     lookahead_config=None,
+    skip_generate=False,
 ):
     is_multimodal = all(
         [
@@ -781,19 +785,22 @@
 
     # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
     # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
-    _generate(
-        inputs,
-        cur_pos,
-        module,
-        tokenizer,
-        tok_embedding,
-        ar_len,
-        max_seq_len,
-        k_caches,
-        v_caches,
-        total_token_list,
-        lookahead_config,
-    )
+    # During calibration, skip_generate=True skips this loop since quantization
+    # observers already have sufficient activation statistics from the prefill pass.
+    if not skip_generate:
+        _generate(
+            inputs,
+            cur_pos,
+            module,
+            tokenizer,
+            tok_embedding,
+            ar_len,
+            max_seq_len,
+            k_caches,
+            v_caches,
+            total_token_list,
+            lookahead_config,
+        )
 
     logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
@@ -900,6 +907,7 @@ def graph_module_inference(
     event_name: Optional[str] = None,
     seq_mse_candidates: int = 0,
     lookahead_config: Optional[Tuple[int]] = None,
+    skip_generate: bool = False,
 ):
     """
     This function supports model execution from static nn.Module decoder model
@@ -915,6 +923,7 @@
     if use_kv_cache:
         kwargs["ar_len"] = ar_len
         kwargs["lookahead_config"] = lookahead_config
+        kwargs["skip_generate"] = skip_generate
 
     INFERENCE_REGISTRY[use_kv_cache](
         get_example_inputs,
@@ -940,6 +949,7 @@
         get_example_inputs=get_example_inputs,
         use_i64_token=use_i64_token,
         seq_mse_candidates=seq_mse_candidates,
+        skip_generate=skip_generate,
     )
     # Evaluate the model
     with torch.no_grad():
diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
index 44917e0bd5a..7ef773463b0 100644
--- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
+++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
@@ -505,6 +505,7 @@ def _calibrate(
             use_i64_token=self.control_args.embedding_quantize is not None,
             event_name=f"{event}_tasks",
             seq_mse_candidates=self.config.seq_mse_candidates,
+            skip_generate=True,
         )
 
         # prepare lookahead config if applicable
@@ -533,6 +534,7 @@ def _calibrate(
             use_i64_token=self.control_args.embedding_quantize is not None,
             event_name=f"{event}_prompt",
             lookahead_config=lookahead_config,
+            skip_generate=True,
         )
 
     @log_info