Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node
from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)

Expand Down Expand Up @@ -92,40 +91,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule):
break


def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
    """Annotate each value returned by the graph's output node with fixed
    quantization parameters taken from ``kv_quant_attrs``.

    Args:
        gm: Graph module whose output-node values correspond 1:1 (by index)
            with entries in ``kv_quant_attrs``.
        kv_quant_attrs: Maps output index to a sequence
            ``(scale, zero_point, quant_min, quant_max, dtype)`` — presumably
            captured from the matching decode-graph KV cache (TODO confirm
            with caller).
    """
    for node in gm.graph.nodes:
        if node.op != "output":
            continue
        for index, prefill_output in enumerate(node.args[0]):
            # Unpack by name instead of opaque kv_quant_attr[0..4] indices.
            scale, zero_point, quant_min, quant_max, dtype = kv_quant_attrs[index][:5]
            fixed_observer = FixedQParamsObserver.with_args(
                scale=scale,
                zero_point=zero_point,
                quant_min=quant_min,
                quant_max=quant_max,
                dtype=dtype,
                # Fix: was `torch.torch.per_tensor_affine`, which only works
                # through the accidental `torch.torch` self-alias; use the
                # canonical attribute.
                qscheme=torch.per_tensor_affine,
            )

            fixed_output_spec = QuantizationSpec(
                quant_min=quant_min,
                quant_max=quant_max,
                dtype=dtype,
                ch_axis=0,
                observer_or_fake_quant_ctr=fixed_observer,
            )

            # Every Node argument feeding this output shares the same fixed
            # spec (non-Node args such as constants are skipped).
            input_qspec_map = {
                arg: fixed_output_spec
                for arg in prefill_output.args
                if isinstance(arg, Node)
            }

            prefill_output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=fixed_output_spec,
                _annotated=True,
            )


def annotate_kv_8bit( # noqa: C901
gm: torch.fx.GraphModule,
is_qat=False,
Expand Down
8 changes: 3 additions & 5 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,7 @@ def forward(

def get_example_inputs(self):
dtype = torch.int64 if self.use_i64_token else torch.int32
tokens = torch.randint(
self.vocab_size, (self.max_batch_size, self.ar_len), dtype=dtype
)
tokens = torch.ones((self.max_batch_size, self.ar_len), dtype=dtype)
atten_mask = AttentionMask(
CausalAttentionMask(self.max_batch_size, self.ar_len, self.max_context_len)
)
Expand All @@ -776,15 +774,15 @@ def get_example_inputs(self):
for _ in range(self.n_layers):
# transpose first to decrease the runtime efforts
k_cache.append(
torch.zeros(
torch.ones(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we initializing the kv cache with different values?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for checking the numerical value with deterministic input for validation purpose. I can revert them back to zeros.

self.max_batch_size,
self.n_kv_heads,
self.head_dim,
self.max_context_len - self.ar_len,
)
)
v_cache.append(
torch.zeros(
torch.ones(
self.max_batch_size,
self.n_kv_heads,
self.max_context_len - self.ar_len,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,108 +294,6 @@ Error MultimodalRunner<T>::load() {
cache_mode_,
static_cast<int32_t>(dim)});

if (eval_mode_ == EvalMode::kLookaheadDecoding ||
eval_mode_ == EvalMode::kHybrid) {
output_k_cache_scales_.resize(num_layers);
output_k_cache_zero_points_.resize(num_layers);
output_v_cache_scales_.resize(num_layers);
output_v_cache_zero_points_.resize(num_layers);
for (int i = 0; i < num_layers; i++) {
std::string get_k_scale_output_name =
"get_k_scale_output_" + std::to_string(i);
std::string get_k_zero_point_output_name =
"get_k_zero_point_output_" + std::to_string(i);
std::string get_v_scale_output_name =
"get_v_scale_output_" + std::to_string(i);
std::string get_v_zero_point_output_name =
"get_v_zero_point_output_" + std::to_string(i);

if (module_->method_names()->count(get_k_scale_output_name) > 0) {
output_k_cache_scales_[i] = static_cast<float>(
ET_UNWRAP(module_->get(get_k_scale_output_name)).toDouble());
} else {
ET_LOG(Error, "Cannot find method %s", get_k_scale_output_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_k_zero_point_output_name) > 0) {
output_k_cache_zero_points_[i] = static_cast<T>(
ET_UNWRAP(module_->get(get_k_zero_point_output_name)).toInt());
} else {
ET_LOG(
Error,
"Cannot find method %s",
get_k_zero_point_output_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_v_scale_output_name) > 0) {
output_v_cache_scales_[i] = static_cast<float>(
ET_UNWRAP(module_->get(get_v_scale_output_name)).toDouble());
} else {
ET_LOG(Error, "Cannot find method %s", get_v_scale_output_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_v_zero_point_output_name) > 0) {
output_v_cache_zero_points_[i] = static_cast<T>(
ET_UNWRAP(module_->get(get_v_zero_point_output_name)).toInt());
} else {
ET_LOG(
Error,
"Cannot find method %s",
get_v_zero_point_output_name.c_str());
return Error::Internal;
}
}
// Load scale and zero point for quantized input KV cache
input_k_cache_scales_.resize(num_layers);
input_k_cache_zero_points_.resize(num_layers);
input_v_cache_scales_.resize(num_layers);
input_v_cache_zero_points_.resize(num_layers);
for (int i = 0; i < num_layers; i++) {
std::string get_k_scale_input_name =
"get_k_scale_input_" + std::to_string(i);
std::string get_k_zero_point_input_name =
"get_k_zero_point_input_" + std::to_string(i);
std::string get_v_scale_input_name =
"get_v_scale_input_" + std::to_string(i);
std::string get_v_zero_point_input_name =
"get_v_zero_point_input_" + std::to_string(i);
if (module_->method_names()->count(get_k_scale_input_name) > 0) {
input_k_cache_scales_[i] = static_cast<float>(
ET_UNWRAP(module_->get(get_k_scale_input_name)).toDouble());
} else {
ET_LOG(Error, "Cannot find method %s", get_k_scale_input_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_k_zero_point_input_name) > 0) {
input_k_cache_zero_points_[i] = static_cast<T>(
ET_UNWRAP(module_->get(get_k_zero_point_input_name)).toInt());
} else {
ET_LOG(
Error,
"Cannot find method %s",
get_k_zero_point_input_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_v_scale_input_name) > 0) {
input_v_cache_scales_[i] = static_cast<float>(
ET_UNWRAP(module_->get(get_v_scale_input_name)).toDouble());
} else {
ET_LOG(Error, "Cannot find method %s", get_v_scale_input_name.c_str());
return Error::Internal;
}
if (module_->method_names()->count(get_v_zero_point_input_name) > 0) {
input_v_cache_zero_points_[i] = static_cast<T>(
ET_UNWRAP(module_->get(get_v_zero_point_input_name)).toInt());
} else {
ET_LOG(
Error,
"Cannot find method %s",
get_v_zero_point_input_name.c_str());
return Error::Internal;
}
}
}

// Initialize EmbeddingGenerator
embedding_generator_ = std::make_unique<EmbeddingProcessor>(
embedding_runner_.get(),
Expand Down Expand Up @@ -599,46 +497,6 @@ Error MultimodalRunner<T>::generate_from_prompt_or_file(
// start the main loop
prompt_tokens.push_back(cur_token);

// Requant kv cache for prefill decode I/O
if (eval_mode_ == EvalMode::kLookaheadDecoding ||
eval_mode_ == EvalMode::kHybrid) {
int64_t num_heads = prompt_processor_->get_num_heads();
int64_t num_layers = prompt_processor_->get_num_layers();
int64_t head_dim = kv_manager_->get_head_dim();
std::vector<KVCache<T>> k_cache_ptrs = kv_manager_->get_k_cache_();
std::vector<KVCache<T>> v_cache_ptrs = kv_manager_->get_v_cache_();

const int64_t num_elems_per_layer =
(context_len_ - 1) * num_heads * head_dim;
// Requant kv cache from prefill output scale/zero_point to decode input
// scale/zero_point
for (int layer_idx = 0; layer_idx < num_layers; layer_idx++) {
T* k_cache_data = k_cache_ptrs[layer_idx].buffer;
T* v_cache_data = v_cache_ptrs[layer_idx].buffer;

const float scale_ratio_k =
output_k_cache_scales_[layer_idx] / input_k_cache_scales_[layer_idx];
const float scale_ratio_v =
output_v_cache_scales_[layer_idx] / input_v_cache_scales_[layer_idx];

for (int64_t i = 0; i < num_elems_per_layer; i++) {
// Requant k_cache_data from prefill output scale/zero_point to decode
// input scale/zero_point
k_cache_data[i] = static_cast<T>(
(k_cache_data[i] - output_k_cache_zero_points_[layer_idx]) *
scale_ratio_k +
input_k_cache_zero_points_[layer_idx]);

// Requant v_cache_data from prefill output scale/zero_point to decode
// input scale/zero_point
v_cache_data[i] = static_cast<T>(
(v_cache_data[i] - output_v_cache_zero_points_[layer_idx]) *
scale_ratio_v +
input_v_cache_zero_points_[layer_idx]);
}
}
}

int64_t num_generated_tokens = ET_UNWRAP(token_generator_->generate(
prompt_tokens, cur_pos_, seq_len, token_callback, dump_logits, nullptr));
stats_.inference_end_ms = time_in_ms();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,6 @@ class MultimodalRunner : public executorch::extension::llm::IRunner {
multimodal_embeddings_dim_order_;
TensorStruct<float> merged_embeddings_;

// scale and zero point for quantized KV cache
std::vector<float> input_k_cache_scales_;
std::vector<T> input_k_cache_zero_points_;
std::vector<float> input_v_cache_scales_;
std::vector<T> input_v_cache_zero_points_;
std::vector<float> output_k_cache_scales_;
std::vector<T> output_k_cache_zero_points_;
std::vector<float> output_v_cache_scales_;
std::vector<T> output_v_cache_zero_points_;

// stats
executorch::llm::Stats stats_;
};
Expand Down
31 changes: 0 additions & 31 deletions examples/qualcomm/oss_scripts/llama/wrappers/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,6 @@ class Mode(Enum):
DECODE = 2


def is_node_src_start_with_name(node: torch.fx.Node, prefix: str) -> bool:
    """
    Return True if any NodeSource in node.meta['from_node']
    has a `name` starting with `prefix`.
    """
    roots = node.meta.get("from_node", None)
    if not roots:
        return False

    # Depth-first walk over the NodeSource tree using an explicit stack
    # instead of recursion; traversal order does not affect the result.
    pending = list(roots)
    while pending:
        source = pending.pop()

        source_name = getattr(source, "name", None)
        if isinstance(source_name, str) and source_name.startswith(prefix):
            return True

        nested = getattr(source, "from_node", None)
        if nested:
            pending.extend(nested)

    return False


def log_info(func):
class TimeIt:
def __init__(self, event):
Expand Down
Loading
Loading