From d279afea73347ea49ad044c36bfd801ff804fae7 Mon Sep 17 00:00:00 2001 From: Connor Dilgren Date: Mon, 21 Apr 2025 14:10:42 -0400 Subject: [PATCH 1/4] toy matmul 2:4 sparsity experiment --- toy_matmul.err | 2 + toy_matmul.out | 69 ++++++++++++++++++++++++++ toy_matmul.py | 131 +++++++++++++++++++++++++++++++++++++++++++++++++ toy_matmul.sh | 25 ++++++++++ 4 files changed, 227 insertions(+) create mode 100644 toy_matmul.err create mode 100644 toy_matmul.out create mode 100644 toy_matmul.py create mode 100755 toy_matmul.sh diff --git a/toy_matmul.err b/toy_matmul.err new file mode 100644 index 0000000..8dc65fd --- /dev/null +++ b/toy_matmul.err @@ -0,0 +1,2 @@ +/home/cdilgren/loki/.venv/lib/python3.10/site-packages/apex-0.1-py3.10-linux-x86_64.egg/apex/contrib/sparsity/sparse_masklib.py:42: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.) + mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m) diff --git a/toy_matmul.out b/toy_matmul.out new file mode 100644 index 0000000..71832df --- /dev/null +++ b/toy_matmul.out @@ -0,0 +1,69 @@ +Loading cuda/12.3.0/gcc/11.3.0/x86_64 + Loading requirement: gcc/11.3.0 +Could not find permutation search CUDA kernels, falling back to CPU path +[ASP][Info] permutation_search_kernels can be imported. +[ASP][Warning] torchvision cannot be imported. +[set_permutation_params_from_asp] Set permutation needed parameters + Sparse parameter names: ['linear:weight'] + All parameter names: [':linear.weight', ':linear.bias', 'linear:weight', 'linear:bias'] +[set_identical_seed] Set the identical seed: 1 for all GPUs to make sure the same results generated in permutation search + +[permute_model] Permuting the model +[build_fx_graph] The torch version is: 2.1.0+cu121, version major is: 2, version minor is: 1, version minimum is: 0+cu121 +[build_fx_graph] The Torch.FX is supported. + +[build_fx_graph] Print the model structure with pure PyTorch function +SimpleModel( + (linear): Linear(in_features=4096, out_features=4096, bias=True) +) + +[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX +graph(): + %x : [num_users=1] = placeholder[target=x] + %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {}) + return linear + +[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX in a table format +[print_raw_fx_graph][Warning] 'print_tabular' relies on the library `tabulate`; run `pip install tabulate` to install it. 
+ +[build_fx_graph] Build the module name and type dictionary +[build_fx_graph] module_name: , module type: +[build_fx_graph] module_name: linear, module type: + +[build_fx_graph] Print the children and parents relationship for each layer +[build_fx_graph] This is the 'input' node: x +[build_fx_graph] This is the 'call_module' node: linear, its parent list: ['x'], its children list: ['output'], its type: torch.nn.modules.linear.Linear +[build_fx_graph] This is the 'output' node: output + +[init_permutation_flags] Initialize the permutation flags for each node according to module type and parameters +Initializing node linear of type torch.nn.modules.linear.Linear: {'parents': ['x'], 'children': ['output'], 'fx_op': 'call_module', 'module_type': 'torch.nn.modules.linear.Linear', 'groups_param': 'None', 'C_param': '4096', 'K_param': '4096'} + Initialized node linear of type torch.nn.modules.linear.Linear: {'parents': ['x'], 'children': ['output'], 'fx_op': 'call_module', 'module_type': 'torch.nn.modules.linear.Linear', 'groups_param': 'None', 'C_param': '4096', 'K_param': '4096', 'C_permutable': True, 'K_permutable': True, 'K_passthru': False, 'is_real': True, 'C_permuted': False, 'K_permuted': False, 'sibling_group_id': None, 'coparent_group_id': None} + +[find_real_parents] Find the real parents for each node according to the whole network graph built with Torch.FX +[find_real_parents] linear has 0 real parents: set() + +[find_real_children] Find the real children for each node according to the whole network graph built with Torch.FX +[find_real_children] node_name: 'linear', children: ['output'] +[find_real_children] linear has 0 real children: set() +[make_sibling_coparent_groups] +New sibling group 0 with GCD(C) of 4096: ['linear'] +New coparent group 0: {'linear'} +[fixup_concats] +[enforce_dimension_agreement] + linear has no real parents, disabling permutations along C + linear has no real children, disabling permutations along K +Making a pass at propagating permutation flags + node linear has poisoned the sibling group of ['linear']: {'parents': ['x'], 'children': ['output'], 'fx_op': 'call_module', 'module_type': 'torch.nn.modules.linear.Linear', 'groups_param': 'None', 'C_param': '4096', 'K_param': '4096', 'C_permutable': False, 'K_permutable': False, 'K_passthru': False, 'is_real': True, 'C_permuted': False, 'K_permuted': False, 'sibling_group_id': 0, 'coparent_group_id': 0, 'real_parents': [], 'real_children': []} + node linear has poisoned the coparent group of {'linear'}: {'parents': ['x'], 'children': ['output'], 'fx_op': 'call_module', 'module_type': 'torch.nn.modules.linear.Linear', 'groups_param': 'None', 'C_param': '4096', 'K_param': '4096', 'C_permutable': False, 'K_permutable': False, 'K_passthru': False, 'is_real': True, 'C_permuted': False, 'K_permuted': False, 'sibling_group_id': 0, 'coparent_group_id': 0, 'real_parents': [], 'real_children': []} +Skipping permutation for sibling group 0 since it does not allow permutations along C + +[permute_model] Take 0.0000 seconds to finish search_for_good_permutation function. +[check_graph_for_unpermuted_nodes] found nodes that missed permutations along 0 dimensions. + +[compute_sparse_masks] permuted the model. +[compute_sparse_masks] Take 0.0021 seconds to find and apply permutations. 
+[ASP] Enabled 50.00% sparsity for linear::weight of size=torch.Size([4096, 4096]) and type=torch.float32 with magnitude tensor(91732.6875, device='cuda:0') +Input size: 64x4096, Output size: 4096 +Dense time: 0.279 ms +Sparse time: 0.240 ms +Speedup: 1.16x diff --git a/toy_matmul.py b/toy_matmul.py new file mode 100644 index 0000000..def277a --- /dev/null +++ b/toy_matmul.py @@ -0,0 +1,131 @@ +import torch +import time +import numpy as np + +# Try to import APEX for sparse operations +try: + from apex.contrib.sparsity import ASP +except ImportError: + print("NVIDIA Apex not found. Please install it using: pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" 'git+https://github.com/NVIDIA/apex.git'") + exit(1) + +# Create a simple model with one linear layer +class SimpleModel(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + return self.linear(x) + +def benchmark_sparsity(): + # Parameters + in_features = 4096 + out_features = 4096 + batch_size = 64 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type != "cuda": + print("CUDA is not available. This benchmark requires an NVIDIA GPU.") + exit(1) + + # Create models + dense_model = SimpleModel(in_features, out_features).to(device) + sparse_model = SimpleModel(in_features, out_features).to(device) + + # ASP.prune_trained_model needs an optimizer for some reason? + optimizer = torch.optim.SGD(sparse_model.parameters(), lr=0.01) + + # Apply 2:4 sparsity to the sparse model + ASP.prune_trained_model(sparse_model, optimizer) + + # Create input tensor + x = torch.randn(batch_size, in_features, device=device) + + # Warmup + for _ in range(10): + dense_model(x) + sparse_model(x) + + # Benchmark dense model + torch.cuda.synchronize() + start_time = time.time() + for _ in range(1000): + dense_output = dense_model(x) + torch.cuda.synchronize() + dense_time = (time.time() - start_time) / 1000 + +# # Calculate FLOPS for dense +# # For a linear layer: 2 * batch_size * in_features * out_features +# dense_flops = 2 * batch_size * in_features * out_features +# dense_tflops = dense_flops / dense_time / 1e12 + + # Benchmark sparse model + torch.cuda.synchronize() + start_time = time.time() + for _ in range(1000): + sparse_output = sparse_model(x) + torch.cuda.synchronize() + sparse_time = (time.time() - start_time) / 1000 + +# # Calculate FLOPS for sparse (50% of dense for 2:4 sparsity) +# sparse_flops = dense_flops * 0.5 +# sparse_tflops = sparse_flops / sparse_time / 1e12 + +# # Check that outputs are close +# assert torch.allclose(dense_output, sparse_output, rtol=1e-2, atol=1e-2), "Outputs are not close!" 
+ + # Print results + print(f"Input size: {batch_size}x{in_features}, Output size: {out_features}") + print(f"Dense time: {dense_time*1000:.3f} ms") #, {dense_tflops:.2f} TFLOPS") + print(f"Sparse time: {sparse_time*1000:.3f} ms") #, {sparse_tflops:.2f} TFLOPS") + print(f"Speedup: {dense_time/sparse_time:.2f}x") + +# # Try to get energy measurements +# try: +# import pynvml +# pynvml.nvmlInit() +# handle = pynvml.nvmlDeviceGetHandleByIndex(0) + +# # Dense energy +# torch.cuda.synchronize() +# pynvml.nvmlDeviceResetApplicationsClocks(handle) +# start_time = time.time() +# start_power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # W + +# for _ in range(100): +# dense_model(x) +# torch.cuda.synchronize() + +# end_time = time.time() +# end_power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # W +# dense_duration = end_time - start_time +# dense_power = (start_power + end_power) / 2 + +# dense_energy = dense_power * dense_duration + +# # Sparse energy +# torch.cuda.synchronize() +# pynvml.nvmlDeviceResetApplicationsClocks(handle) +# start_time = time.time() +# start_power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # W + +# for _ in range(100): +# sparse_model(x) +# torch.cuda.synchronize() + +# end_time = time.time() +# end_power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # W +# sparse_duration = end_time - start_time +# sparse_power = (start_power + end_power) / 2 +# sparse_energy = sparse_power * sparse_duration + +# print(f"Dense power: {dense_power:.2f} W, Energy: {dense_energy:.2f} J") +# print(f"Sparse power: {sparse_power:.2f} W, Energy: {sparse_energy:.2f} J") +# print(f"Energy savings: {dense_energy/sparse_energy:.2f}x") + +# pynvml.nvmlShutdown() +# except: +# print("Energy measurements not available. Install pynvml for energy measurements.") + +if __name__ == "__main__": + benchmark_sparsity() diff --git a/toy_matmul.sh b/toy_matmul.sh new file mode 100755 index 0000000..024be0d --- /dev/null +++ b/toy_matmul.sh @@ -0,0 +1,25 @@ +#!/bin/bash +#SBATCH --job-name=loki_evaluate_tasks # Job name +#SBATCH --output=loki_evaluate_tasks_%j.out # Output file (%j expands to jobID) +#SBATCH --error=loki_evaluate_tasks_%j.err # Error file (%j expands to jobID) +#SBATCH --ntasks=1 # Number of tasks +#SBATCH --cpus-per-task=4 # CPU cores per task +#SBATCH --mem=32G # Memory requirement +#SBATCH --time=00:04:00 # Time limit (HH:MM:SS) +#SBATCH --partition=gpu # Partition/queue name +#SBATCH --account=cmsc828-class +#SBATCH --gpus=a100:1 # Request specific GPU type +#SBATCH --mail-type=BEGIN,END +#SBATCH --mail-user=cdilgren@umd.edu + +# Load necessary modules (you may need to adjust these for your cluster) +module purge +module load cuda/12.3.0/gcc/11.3.0/x86_64 + +# Activate virtual environment +source /home/cdilgren/loki/.venv/bin/activate + +# Run the Python script with your arguments +python toy_matmul.py + +# End of script From 4ff36773322336ca29811af9702df99b9d81b7bc Mon Sep 17 00:00:00 2001 From: Sukriti Paul Date: Mon, 12 May 2025 07:11:20 -0400 Subject: [PATCH 2/4] Implement naive sparsity of QK weight tensors (pruning) --- README.md | 105 +--- methods/pca_topk/attention_benchmark_apex.py | 518 +++++++++++++++++++ methods/pca_topk/sparsity_utils.py | 280 ++++++++++ requirements_merged.txt | 72 +++ test_attention_benchmark_fixed.py | 299 +++++++++++ 5 files changed, 1192 insertions(+), 82 deletions(-) create mode 100644 methods/pca_topk/attention_benchmark_apex.py create mode 100644 methods/pca_topk/sparsity_utils.py create mode 100644 requirements_merged.txt create 
mode 100644 test_attention_benchmark_fixed.py diff --git a/README.md b/README.md index 5102027..a130f0e 100644 --- a/README.md +++ b/README.md @@ -1,96 +1,37 @@ -# Loki +# Loki (Naive Sparsity) -This repository contains the code related to the experiments in the paper [Loki: Low-Rank Keys for Efficient Sparse Attention](https://arxiv.org/abs/2406.02542). -We provide the code to compute the PCA of the keys for various models, baseline method implementations and kernels for Loki used in the paper, along with scripts to evaluate the methods on perplexity evaluation and downstream tasks. -## Installation -You need to install the requirements as follows: +## Attention Query-Key Sparsity Mode -``` -pip install -r requirements.txt -``` - -Note: The code requires specific versions of the huggingface transformers library present in the requirements.txt file. It will not work with other versions. - -## Usage - -#### Compute PCA of keys for a model -Say you want to compute the PCA transform for the keys of Llama-2-7b model. You can do this by following the steps below: - -- Run perplexity evaluation on the model on a target dataset to save the keys, queries and values tensors. - ```bash - # The --use-axonn flag is optional and is used to shard the model over multiple GPUs using AxoNN - - python -u evaluate_tasks.py --sequence-length 4096 --model-id meta-llama/Llama-2-7b-hf --model-type llama --dataset wikitext-valid --save-tensors --tensor-dir --use-topk --top-k 1 [--use-axonn] - ``` - List of possible datasets - wikitext-valid, bookcorpus, c4 +When running with `sparsity_type = "attention-query-key"` (enabled via `--run-attention-query-key-sparsity` flag), the implementation: -- Compute the PCA of the generated keys: In the `pca_analysis` directory, run the following command: +1. Applies magnitude-based pruning to keep only the top 50% of weights by magnitude in linear projection layers +2. Optimizes the attention mechanism's query-key operations using torch's native sparse matrix multiplication (more below) +3. The sparse representations are computed once and reused across forward passes - ```bash - python pca.py key - ``` -Verify that the PCA transform are saved in the output directory. Do not modify the subdirectory structure of the output directory as it is used by the downstream tasks evaluation code. -#### Running the ML evaluations -Once the PCA transform is computed, we can run the ML evaluations using Loki. 
The following command runs the evaluation on the downstream tasks using the PCA transform computed in the previous step:
+## Modified Files
+- `test_attention_benchmark_fixed.py`: Main benchmark script
+- `./methods/pca_topk/attention_benchmark_apex.py`: Implementation of attention mechanisms
+- `./methods/pca_topk/sparsity_utils.py`: Utilities for sparse operations
-```bash
-python -u evaluate_tasks.py \
-    --sequence-length 4096 \
-    --model-id meta-llama/Llama-2-7b-hf \
-    --model-type llama
-    --use-pca-topk
-    --top-r <16/32/64>
-    --top-k <0.125/0.25/0.5> \
-    --rotary-type \
-    --dataset \
-    --transform-dataset \
-    [--lm-harness-eval] \ # Flag to evaluate the model on the LM Harness Tasks
-    [--use-wandb] \ # Optional flag to log the results to wandb
-    [--use-axonn] # Optional flag to shard the model over multiple GPUs using AxoNN
-```
+## Usage Example
-
-#### Running compute evaluation
-To run the compute evaluation, you can use the following command:
+Run the benchmark with sparse attention query-key optimization:
 ```bash
-python evaluate_compute.py
-
+python test_attention_benchmark_fixed.py --orig-pca-dir /cmlscratch/sukriti5/pca_stuff/pca_components --cache-seq-len 3500 --num-gen-steps 10 --top-d 32 --num-heads 16 --run-loki-without-sparsity --run-attention-query-key-sparsity --output-csv ./attention_query_key_results.csv
 ```
-This will run the attention benchmark with Loki and vanilla attention assuming a Llama2-13B type model and save the results in a `compute_files` directory.
-
+### Sparse Matrix Multiplication
+- For 2D tensors: Using `torch.sparse.mm` for efficient sparse matrix multiplication
+- For 3D tensors (batched attention operations): Using `torch.bmm` for batched multiplication
+- Precision handling: Converting half-precision tensors to float32 temporarily for sparse operations, as PyTorch's sparse operations don't directly support half-precision
+### Sparsity Configuration
+- Default sparsity level: 50% (i.e., keeping the top 50% of weights by magnitude for now)
+- Implementation: Magnitude-based pruning applied to linear projection layers
+- Format: Using `torch.sparse_coo_tensor` to create COO-format sparse representations
+- One-time sparse conversion: In the `SparseLinear` class, the sparse representation is created only once via the `make_sparse()` method. For caching, the sparse weight (`self._sparse_weight`) is stored as a class attribute and reused across forward passes (see the sketch at the end of this README).
+I've followed the same benchmark structure and attributes as in your and Connor's benchmark files.
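+
+### Minimal Sketch of the Sparse Projection Path
+The snippet below is a minimal, self-contained sketch of the path described above (magnitude pruning to 50% density, one-time COO conversion, `torch.sparse.mm` with a temporary float32 upcast for half-precision inputs). It mirrors what `SparseLinear.make_sparse()` and its `forward()` do, but the helper names and tensor shapes here are illustrative only; it is not the benchmarked implementation.
+
+```python
+import torch
+
+def magnitude_prune(weight: torch.Tensor, density: float = 0.5) -> torch.Tensor:
+    # Zero out all but the top `density` fraction of weights by magnitude.
+    k = int(weight.numel() * density)
+    threshold = torch.kthvalue(weight.abs().flatten().float(), weight.numel() - k + 1).values
+    mask = (weight.abs().float() >= threshold).to(weight.dtype)
+    return weight * mask
+
+def sparse_linear_forward(weight, bias, x):
+    # In SparseLinear the COO tensor is built once and cached; it is rebuilt here for clarity.
+    sparse_w = magnitude_prune(weight).to_sparse().coalesce()
+    # torch.sparse.mm does not support half precision, so upcast temporarily.
+    y = torch.sparse.mm(sparse_w.float(), x.float().t()).t()
+    if bias is not None:
+        y = y + bias.float()
+    return y.to(x.dtype)
+
+if __name__ == "__main__":
+    x = torch.randn(4, 1024).half()       # (batch, in_features)
+    w = torch.randn(1024, 1024).half()    # (out_features, in_features)
+    b = torch.zeros(1024).half()
+    print(sparse_linear_forward(w, b, x).shape)  # torch.Size([4, 1024])
+```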
diff --git a/methods/pca_topk/attention_benchmark_apex.py b/methods/pca_topk/attention_benchmark_apex.py new file mode 100644 index 0000000..1645f0d --- /dev/null +++ b/methods/pca_topk/attention_benchmark_apex.py @@ -0,0 +1,518 @@ +from typing import Any, Dict, List, Optional, Tuple +from transformers.cache_utils import Cache +import math +import time +import torch +import methods.pca_topk.kernel.pca_topk as G +from methods.common.timers import Timers +from methods.pca_topk.sparsity_utils import SparsityHandler +import json + + +class PcaTopKCache(Cache): + """ + Cache based on PcaTopK mechanism + Note: This class is now just a wrapper around the Cache class from transformers.cache_utils + """ + def __init__(self) -> None: + self.key_cache: List[torch.Tensor] = [] # Stores the reduced keys for each layer + self.value_cache: List[torch.Tensor] = [] + + @torch.no_grad() + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + query_states: torch.Tensor, + layer_idx: int, + is_prompt: bool = False, # Added is_prompt parameter with default value + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + is_prompt (`bool`, `optional`): + Whether this update is for the prompt or for generation. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + A tuple containing the updated key and value states. + """ + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + # This is also the prompt iteration so we need all the keys for attention + return self.key_cache[layer_idx], self.value_cache[layer_idx] + else: + # Ensure dimensions match before concatenation + if key_states.dim() != self.key_cache[layer_idx].dim(): + # Adjust dimensions to match + if key_states.dim() < self.key_cache[layer_idx].dim(): + # Add missing dimensions + for _ in range(self.key_cache[layer_idx].dim() - key_states.dim()): + key_states = key_states.unsqueeze(0) + value_states = value_states.unsqueeze(0) + + # Check sequence length dimension + cache_seq_len = self.key_cache[layer_idx].shape[-2] + new_seq_len = key_states.shape[-2] + + # Ensure shapes match in all dimensions except sequence length + if self.key_cache[layer_idx].shape[:-2] != key_states.shape[:-2]: + # Adjust shapes to match + if self.key_cache[layer_idx].shape[0] != key_states.shape[0]: + # Reshape to match number of heads + key_states = key_states.expand(self.key_cache[layer_idx].shape[0], *key_states.shape[1:]) + value_states = value_states.expand(self.value_cache[layer_idx].shape[0], *value_states.shape[1:]) + + # Now concatenate along sequence length dimension + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reset(self): + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + +class DenseAttentionProj: + """Implements realistic query/key/value projections for a transformer layer.""" + def __init__(self, hidden_size, num_heads, head_dim, dtype=torch.float16, device='cuda', sparsity_handler=None): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.sparsity_handler = sparsity_handler + + # Initialize Q/K/V projection matrices as real transformers would have + self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, dtype=dtype, device=device) + self.k_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, dtype=dtype, device=device) + self.v_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, dtype=dtype, device=device) + + # Output projection + self.out_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, dtype=dtype, device=device) + + # Apply sparsity to weight matrices if sparsity handler is provided + if self.sparsity_handler and self.sparsity_handler.sparsity_type == "attention-query-key": + self.q_proj = self.sparsity_handler.sparsify_linear_weight(self.q_proj, "q_proj") + self.k_proj = self.sparsity_handler.sparsify_linear_weight(self.k_proj, "k_proj") + # We typically don't sparsify the value projection, but you can uncomment if needed + # self.v_proj = self.sparsity_handler.sparsify_linear_weight(self.v_proj, "v_proj") + + def _shape(self, tensor, seq_len, bsz): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward(self, hidden_states): + """ + Compute Q, K, V projections from input embeddings as in a real transformer + + Args: + hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + query_states, key_states, value_states of shapes (batch_size, num_heads, seq_len, head_dim) + """ + bsz, seq_len = hidden_states.shape[0], hidden_states.shape[1] + + # Project to Q, K, V using sparse matrix multiplication if available + if self.sparsity_handler and self.sparsity_handler.sparsity_type == "attention-query-key": + # The weight matrices are already sparsified during initialization + # We just need to use them for the projection + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + # Regular dense projection + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (batch_size, num_heads, seq_len, head_dim) + query_states = self._shape(q, seq_len, bsz) + key_states = self._shape(k, seq_len, bsz) + value_states = self._shape(v, seq_len, bsz) + + return query_states, key_states, value_states + + @torch.no_grad() + def output_projection(self, attn_output): + """ + Apply output 
projection to attention output + + Args: + attn_output: Attention output of shape (batch_size, num_heads, seq_len, head_dim) + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + bsz, seq_len = attn_output.shape[0], attn_output.shape[2] + + # Reshape to (batch_size, seq_len, num_heads * head_dim) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + + # Apply output projection + return self.out_proj(attn_output) + + +def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_layers, timers, + num_gen_steps=2000, use_optimised_gather=False, sparsity_type=None): + import time + torch.set_float32_matmul_precision("highest") + + head_dim = prompt_keys[0].shape[-1] + bs = prompt_keys[0].shape[0] + num_heads = prompt_keys[0].shape[1] + hidden_size = num_heads * head_dim + dtype = prompt_keys[0].dtype + prompt_seq_length = prompt_keys[0].shape[2] + + # Initialize sparsity handler if sparsity type is specified + sparsity_handler = None + if sparsity_type: + sparsity_handler = SparsityHandler(sparsity_type) + + # PCA projection matrix + pca_projection_mat = torch.randn(num_heads, head_dim, head_dim, dtype=dtype, device='cuda') + + # Apply sparsity pattern to the projection matrix if sparsity handler is initialized + if sparsity_handler: + # We've already verified the SparsityHandler is working correctly during setup + pca_projection_mat = sparsity_handler.apply_sparsity_pattern(pca_projection_mat) + # Get sparsity statistics + stats = sparsity_handler.get_sparsity_stats(pca_projection_mat) + + # Initialize dense attention projections for each layer (more realistic) + attention_projections = [ + DenseAttentionProj(hidden_size, num_heads, head_dim, dtype=dtype, sparsity_handler=sparsity_handler) + for _ in range(num_layers) + ] + + # Initial input embedding to start the autoregressive process + input_embedding = torch.rand(bs, 1, hidden_size, device='cuda', dtype=dtype) + + assert use_optimised_gather + if use_optimised_gather: + timers.start('total') + for i in range(num_gen_steps): + for layer in range(num_layers): + # Use the previous output as input (autoregressive) for realistic benchmarking + layer_input = input_embedding + + timers.start('qk-gen') + # Generate query/key/value from the same input embedding (realistic) + query_states, key_states, value_states = attention_projections[layer].forward(layer_input) + timers.stop('qk-gen') + + timers.start('project') + # Apply PCA projection (if using) + projected_key = key_states.squeeze(2).transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2) + projected_query = query_states.squeeze(2).transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2) + timers.stop('project') + + # Note: We no longer apply sparsity here since it's now applied to the weight matrices before projection + + # Ensure projected_key and value_states have the same shape format as the cache + # The cache expects [num_heads, batch_size, seq_len, head_dim] format + if projected_key.shape[0] != num_heads: + # Transpose to match cache format if needed + if projected_key.dim() == 4 and projected_key.shape[1] == bs: + # Already in [num_heads, bs, seq_len, head_dim] format + pass + elif projected_key.dim() == 3: + # Add batch dimension if missing + projected_key = projected_key.unsqueeze(1) + else: + # Reshape to match expected format + projected_key = projected_key.view(num_heads, bs, -1, head_dim) + + # Ensure value_states has the same format as the cache + if value_states.shape[0] != num_heads: + # Transpose to match cache format 
if needed + if value_states.dim() == 4 and value_states.shape[1] == bs: + # Already in [num_heads, bs, seq_len, head_dim] format + pass + elif value_states.dim() == 3: + # Add batch dimension if missing + value_states = value_states.unsqueeze(1) + else: + # Reshape to match expected format + value_states = value_states.view(num_heads, bs, -1, head_dim) + + timers.start('cache-update') + keys, vals = cache.update(projected_key, value_states, projected_query, layer, False) + timers.stop('cache-update') + + timers.start('qk-matmul-1') + nh, bs, s, r = keys.shape + + # Use sparse matrix multiplication if needed + if sparsity_handler and sparsity_type == "attention-query-key": + attn_weights = sparsity_handler.sparse_matmul( + projected_query.view(nh*bs, 1, r), + keys.view(nh*bs, s, r).transpose(-1,-2) + ) + else: + attn_weights = G.topr_bmv_optimized( + A=projected_query.view(nh*bs, 1, r), + B=keys.view(nh*bs, s, r).transpose(-1,-2), + r=top_r + ) + attn_weights = attn_weights.view(nh, bs, 1, s) + timers.stop('qk-matmul-1') + + # Get top-k keys and top-k values based on the attention scores + timers.start('top-k') + key_states_topk_indices = torch.argsort(attn_weights, dim=-1, descending=True)[:,:,:,:top_k] + timers.stop('top-k') + + timers.start('reshape-0') + key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1]) + timers.stop('reshape-0') + + timers.start('reshape-1') + keys = keys.view(-1, keys.shape[-2] , keys.shape[-1]) + vals = vals.view(-1, vals.shape[-2] , vals.shape[-1]) + timers.stop('reshape-1') + + timers.start('qk-matmul-2') + # Use sparse matrix multiplication if needed + if sparsity_handler and sparsity_type == "attention-query-key": + # For sparse matmul, we need to gather the keys first + gathered_keys = torch.gather( + keys, + 1, + key_states_topk_indices.unsqueeze(-1).expand(-1, -1, keys.size(-1)) + ) + + # Then perform sparse matmul + attn_weights = sparsity_handler.sparse_matmul( + projected_query.reshape(-1, 1, head_dim), + gathered_keys.transpose(-1, -2) + ) / math.sqrt(head_dim) + else: + attn_weights = G.gather_outer_bmv_optimized( + projected_query.reshape(-1, 1, head_dim), + keys.transpose(-1, -2), + key_states_topk_indices, + ) / math.sqrt(head_dim) + timers.stop('qk-matmul-2') + + timers.start('softmax') + attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype) + timers.stop('softmax') + + timers.start('sv-matmul') + attn_output = G.gather_inner_matrix_only_bmv_optimized( + attn_weights, vals, key_states_topk_indices) + timers.stop('sv-matmul') + + # Reshape is handled along with output projection + + # Apply output projection (realistic) - include in reshape-output timing + timers.start('reshape-output') + attn_output = attn_output.view(num_heads, bs, 1, head_dim).transpose(0,1) + output = attention_projections[layer].output_projection(attn_output) + timers.stop('reshape-output') + + # Store this layer's output as input for the next layer or next token + input_embedding = output + + timers.stop('total') + else: + # Fallback implementation with basic operations (non-optimized) + # ...omitted for brevity... 
+ pass + + +def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000): + import time + torch.set_float32_matmul_precision("highest") + + head_dim = prompt_keys[0].shape[-1] + bs = prompt_keys[0].shape[0] + num_heads = prompt_keys[0].shape[1] + hidden_size = num_heads * head_dim + dtype = prompt_keys[0].dtype + + # Initialize dense attention projections for each layer (more realistic) + attention_projections = [ + DenseAttentionProj(hidden_size, num_heads, head_dim, dtype=dtype, sparsity_handler=sparsity_handler) + for _ in range(num_layers) + ] + + # Initial input embedding to start the autoregressive process + input_embedding = torch.rand(bs, 1, hidden_size, device='cuda', dtype=dtype) + + timers.start('total') + for i in range(num_gen_steps): + for layer in range(num_layers): + # Use the previous output as input (autoregressive) + layer_input = input_embedding + + timers.start('qk-gen') + # Generate query/key/value from the same input embedding (realistic) + query_states, key_states, value_states = attention_projections[layer].forward(layer_input) + timers.stop('qk-gen') + + timers.start('cache-update') + keys, vals = cache.update(key_states, value_states, query_states, layer, False) + timers.stop('cache-update') + + timers.start('qk-matmul-1') + attn_weights = torch.matmul(query_states, keys.transpose(2, 3)) / math.sqrt(head_dim) + timers.stop('qk-matmul-1') + + timers.start('softmax') + attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype) + timers.stop('softmax') + + timers.start('sv-matmul') + attn_output = torch.matmul(attn_weights, vals) + timers.stop('sv-matmul') + + timers.start('reshape-output') + # Apply output projection (realistic) as part of reshape timing + output = attention_projections[layer].output_projection(attn_output) + timers.stop('reshape-output') + + # Store this layer's output as input for the next layer or next token + input_embedding = output + + timers.stop('total') + +@torch.no_grad() +def benchmark_attention(batch_size=1, + num_heads=32, + num_gen_steps=128, + prompt_length=3072, + topk=256, + topr=32, + num_layers=32, + dtype=torch.float16, + vanilla=True, + pcatopk=True, + sparsity_type=None, + ): + + # If sparsity is requested, verify that the sparse multiplication is available + if sparsity_type: + # Only allow attention-query-key sparsity type + if sparsity_type != "attention-query-key": + print(f"Warning: Sparsity type '{sparsity_type}' is not supported. 
Only 'attention-query-key' is supported.") + print("Setting sparsity_type to 'attention-query-key'") + sparsity_type = "attention-query-key" + + try: + # Initialize sparsity handler + from methods.pca_topk.sparsity_utils import SparsityHandler + sparsity_handler = SparsityHandler(sparsity_type) + + # Create some test tensors + test_q = torch.rand(2, 1, 128, device='cuda', dtype=dtype) + test_k = torch.rand(2, 10, 128, device='cuda', dtype=dtype) + + # Try the sparse matrix multiplication + _ = sparsity_handler.sparse_matmul(test_q, test_k.transpose(-1, -2)) + + # If we get here, sparse matmul is working + print(f"Verified that sparse matmul for sparsity type '{sparsity_type}' is working.") + except Exception as e: + # If sparse matmul fails, abort the benchmark - don't continue with a fallback + raise RuntimeError(f"Sparse matmul for '{sparsity_type}' failed verification: {e}\n" + f"Cannot proceed with benchmark as sparse optimization is required.") + + head_dim=128 + hidden_size = num_heads * head_dim + + # Generate initial prompt embeddings (more realistic than pure random keys) + prompt_embeddings = [torch.rand(batch_size, prompt_length, hidden_size, device='cuda', dtype=dtype) for _ in range(num_layers)] + + # Create attention projections + # Initialize sparsity_handler to None if not defined (for Loki baseline) + if 'sparsity_handler' not in locals() and sparsity_type is None: + sparsity_handler = None + + attention_projections = [ + DenseAttentionProj(hidden_size, num_heads, head_dim, dtype=dtype, sparsity_handler=sparsity_handler) + for _ in range(num_layers) + ] + + # Create prompt keys, values using proper projections + prompt_keys = [] + prompt_values = [] + + for i in range(num_layers): + q, k, v = attention_projections[i].forward(prompt_embeddings[i]) + prompt_keys.append(k) + prompt_values.append(v) + + times_pca_topk = None + if pcatopk: + print("PCA TOPK Optimized") + for _ in range(10): + cache2 = PcaTopKCache() + for i in range(num_layers): + cache2.update(prompt_keys[i].transpose(0,1).contiguous(), + prompt_values[i].transpose(0,1).contiguous(), + prompt_keys[i].transpose(0,1).contiguous(), i) + timers = Timers() + micro_benchmark_pca_topk(cache2, prompt_keys, topr, topk, + num_gen_steps=num_gen_steps, num_layers=num_layers, + use_optimised_gather=True, timers=timers, + sparsity_type=sparsity_type) + del cache2 + times = timers.get_times() + print(times) + + print("Average time (minus cache updates) is - ") + print(times['total'] - times['cache-update'], " s") + print("==================================") + times_pca_topk = times + + times_vanilla = None + if vanilla: + print("Actual Attention") + for _ in range(10): + cache3= PcaTopKCache() + for i in range(num_layers): + cache3.update(prompt_keys[i], prompt_values[i], prompt_keys[i], i) + timers = Timers() + micro_bench_actual_attention(cache3, prompt_keys, num_layers=num_layers, + num_gen_steps=num_gen_steps, timers=timers) + del cache3 + times = timers.get_times() + print("Average time (minus cache updates) is - ") + print(times['total'] - times['cache-update'], " s") + print(times) + print("==================================") + times_vanilla = times + return times_pca_topk, times_vanilla \ No newline at end of file diff --git a/methods/pca_topk/sparsity_utils.py b/methods/pca_topk/sparsity_utils.py new file mode 100644 index 0000000..5da0d1e --- /dev/null +++ b/methods/pca_topk/sparsity_utils.py @@ -0,0 +1,280 @@ +import torch +import ipdb +import math + + +def tensor_sparsity(tensor): + """Calculate the sparsity 
(fraction of zeros) in a tensor""" + return (tensor == 0).sum().item() / tensor.numel() + +class SparseLinear(torch.nn.Module): + """Custom Linear layer that uses optimized PyTorch sparse matrix multiplication""" + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype)) + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, device=device, dtype=dtype)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.is_sparse = False + self._sparse_weight = None + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(self.bias, -bound, bound) + + def make_sparse(self): + """Convert weight to sparse format using optimized PyTorch sparse operations""" + self.is_sparse = True + + # Create optimized sparse representation using PyTorch + # Get indices of non-zero elements + indices = torch.nonzero(self.weight).t() + + # Get values at these indices + values = self.weight[indices[0], indices[1]] + + # Create sparse tensor in COO format (most efficient for matrix multiplication) + self._sparse_weight = torch.sparse_coo_tensor( + indices, values, self.weight.shape, + device=self.weight.device, + dtype=self.weight.dtype + ).coalesce() # Coalesce to combine duplicate indices for better performance + + def forward(self, input): + if not self.is_sparse: + # Convert to sparse on first use if not already done + self.make_sparse() + self.is_sparse = True + + # Store original dtype for later conversion back + original_dtype = input.dtype + + # Handle input dimensions for sparse matrix multiplication + original_shape = input.shape + need_reshape = len(original_shape) > 2 + + if need_reshape: + # Reshape to 2D for matrix multiplication + input_2d = input.reshape(-1, original_shape[-1]) + else: + input_2d = input + + if input_2d.shape[-1] != self.in_features: + # Transpose if dimensions don't match + input_2d = input_2d.t() + + + if self._sparse_weight is None: + self.make_sparse() + + + if input.dtype == torch.float16 or input.dtype == torch.bfloat16: + input_2d = input_2d.float() + sparse_weight = self._sparse_weight.float() + else: + sparse_weight = self._sparse_weight + + # Use PyTorch's optimized sparse matrix multiplication + output_2d = torch.sparse.mm(sparse_weight, input_2d.t()).t() + + # Convert back to original dtype if needed + if output_2d.dtype != original_dtype: + output_2d = output_2d.to(original_dtype) + + + if self.bias is not None: + output_2d += self.bias + if need_reshape: + output_shape = list(original_shape[:-1]) + [self.out_features] + return output_2d.reshape(*output_shape) + else: + return output_2d + + +class SparsityHandler: + def __init__(self, sparsity_type="attention-query-key"): + self.sparsity_type = sparsity_type + self.density = 0.1 + # For block sparsity in attention-query-key + self.block_size = 16 + self.block_density = 0.1 + self.sparsified_weights = {} + + def apply_sparsity_pattern(self, projection_matrix): + """Apply the selected sparsity pattern to the projection matrix""" + stats = self.get_sparsity_stats(projection_matrix) + if self.sparsity_type == "attention-query-key": + return projection_matrix + else: 
+ return projection_matrix + + def apply_query_key_sparsity(self, matrix): + """Apply block sparsity to query/key matrices after projection""" + if self.sparsity_type != "attention-query-key": + return matrix + + try: + if not matrix.is_contiguous(): + matrix = matrix.contiguous() + + shape = matrix.shape + + flat_matrix = matrix.view(-1) + k = int(flat_matrix.numel() * self.density) + + threshold = torch.kthvalue(torch.abs(flat_matrix), flat_matrix.numel() - k + 1)[0] + + mask = (torch.abs(matrix) >= threshold).to(matrix.dtype) + sparse_matrix = matrix * mask + + return sparse_matrix + + except Exception as e: + return matrix + + def sparsify_linear_weight(self, linear_layer, layer_name=None): + """Convert a standard Linear layer to a SparseLinear layer with sparsified weights""" + if self.sparsity_type != "attention-query-key": + return linear_layer + + if layer_name and layer_name in self.sparsified_weights: + return linear_layer + + try: + sparse_layer = SparseLinear( + in_features=linear_layer.in_features, + out_features=linear_layer.out_features, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype + ) + + sparse_layer.weight.data.copy_(linear_layer.weight.data) + if linear_layer.bias is not None: + sparse_layer.bias.data.copy_(linear_layer.bias.data) + + flat_weight = sparse_layer.weight.view(-1) + sorted_weights = torch.sort(torch.abs(flat_weight), descending=True)[0] + + threshold_idx = int(flat_weight.numel() * 0.5) # 50% density + threshold = sorted_weights[threshold_idx] + + mask = (torch.abs(sparse_layer.weight) >= threshold).to(sparse_layer.weight.dtype) + + sparse_layer.weight.data *= mask + + sparse_layer.make_sparse() + + if layer_name: + self.sparsified_weights[layer_name] = True + + return sparse_layer + except Exception as e: + try: + + weight = linear_layer.weight.data + + + flat_weight = weight.view(-1) + k = int(flat_weight.numel() * 0.5) # 50% density + + threshold = torch.kthvalue(torch.abs(flat_weight), flat_weight.numel() - k + 1)[0] + mask = (torch.abs(weight) >= threshold).to(weight.dtype) + linear_layer.weight.data *= mask + + + if layer_name: + self.sparsified_weights[layer_name] = True + + return linear_layer + except Exception: + + return linear_layer + + def sparse_matmul(self, a, b): + """Perform efficient sparse matrix multiplication using optimized PyTorch sparse operations""" + # Store original dtype for later conversion back + original_dtype = a.dtype + + # Convert to float32 if using half precision, as sparse operations don't support half + if a.dtype == torch.float16 or a.dtype == torch.bfloat16: + a = a.float() + b = b.float() + + # Handle the attention query-key case (3D tensors) + if a.dim() == 3 and b.dim() == 3: + # For query-key attention: [batch, seq_q, dim] @ [batch, dim, seq_k] -> [batch, seq_q, seq_k] + batch_size = a.shape[0] + seq_len_q = a.shape[1] + hidden_dim = a.shape[2] + seq_len_k = b.shape[2] + + # Use torch.bmm for batched matrix multiplication (dense) + result = torch.bmm(a, b) + + # Convert back to original dtype + if original_dtype != result.dtype: + result = result.to(original_dtype) + + return result + + # For 2D case, use sparse matrix multiplication + else: + # Ensure tensors are 2D + if a.dim() > 2: + a_2d = a.reshape(-1, a.shape[-1]) + else: + a_2d = a + + if b.dim() > 2: + b_2d = b.reshape(-1, b.shape[-1]) + else: + b_2d = b + + # Handle dimension mismatch + if a_2d.shape[1] != b_2d.shape[0] and a_2d.shape[1] == b_2d.shape[1]: + b_2d = b_2d.transpose(0, 1) + + # 
Convert to sparse + indices = torch.nonzero(a_2d).t() + if indices.shape[1] == 0: # Handle all zeros + result_2d = torch.zeros(a_2d.shape[0], b_2d.shape[1], device=a.device, dtype=a.dtype) + else: + values = a_2d[indices[0], indices[1]] + sparse_a = torch.sparse_coo_tensor( + indices, values, a_2d.shape, + device=a_2d.device, + dtype=a_2d.dtype + ).coalesce() + + # Perform sparse matrix multiplication + result_2d = torch.sparse.mm(sparse_a, b_2d) + + # Convert back to original dtype + if original_dtype != result_2d.dtype: + result_2d = result_2d.to(original_dtype) + + + if a.dim() > 2 or b.dim() > 2: + target_shape = list(a.shape[:-1]) + [b.shape[-1] if b.dim() > 1 else 1] + return result_2d.view(*target_shape) + else: + return result_2d + + def get_sparsity_stats(self, tensor): + """Get sparsity statistics for the tensor""" + total_elements = tensor.numel() + zero_elements = torch.sum(tensor == 0).item() + sparsity_ratio = zero_elements / total_elements + return { + 'total_elements': total_elements, + 'zero_elements': zero_elements, + 'sparsity_ratio': sparsity_ratio + } diff --git a/requirements_merged.txt b/requirements_merged.txt new file mode 100644 index 0000000..0a87024 --- /dev/null +++ b/requirements_merged.txt @@ -0,0 +1,72 @@ +# Base requirements (excluding conflicting packages) +aiohappyeyeballs==2.6.1 +aiohttp==3.11.16 +aiosignal==1.3.2 +appdirs==1.4.4 +attrs==25.3.0 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +datasets==2.17.1 +dill==0.3.8 +docker-pycreds==0.4.0 +filelock==3.18.0 +frozenlist==1.6.0 +fsspec==2023.10.0 +gitdb==4.0.12 +GitPython==3.1.44 +huggingface-hub>=0.24.0 +idna==3.10 +Jinja2==3.1.6 +MarkupSafe==3.0.2 +mpmath==1.3.0 +multidict==6.4.3 +multiprocess==0.70.16 +networkx==3.4.2 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.1.105 +packaging==25.0 +pandas==2.2.3 +propcache==0.3.1 +protobuf==4.25.6 +psutil==7.0.0 +pyarrow==19.0.1 +pyarrow-hotfix==0.6 +python-dateutil==2.9.0.post0 +pytz==2025.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +safetensors==0.4.3 +sentry-sdk==2.26.1 +setproctitle==1.3.5 +six==1.17.0 +smmap==5.0.2 +sympy==1.13.3 +tokenizers>=0.21 +torch==2.1.0 +tqdm==4.67.1 +triton==2.1.0 +typing_extensions==4.13.2 +tzdata==2025.2 +urllib3==2.4.0 +wandb==0.16.3 +xxhash==3.5.0 +yarl==1.20.0 + +# Using newer versions for conflicting packages +transformers==4.47.0 +accelerate==1.6.0 +peft==0.15.2 +lm_eval==0.4.8 diff --git a/test_attention_benchmark_fixed.py b/test_attention_benchmark_fixed.py new file mode 100644 index 0000000..7a1549a --- /dev/null +++ b/test_attention_benchmark_fixed.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +"""Interactive test script for attention_benchmark_apex.py +This script allows testing a single case of the benchmark with specified parameters. +""" + +from methods.pca_topk.attention_benchmark_apex import benchmark_attention +import argparse +import json +import torch +import os +import gc +import sys +import traceback + +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
+torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +# Set thread limits to avoid OpenBLAS thread creation errors +os.environ["OMP_NUM_THREADS"] = "1" # Set OpenMP threads +os.environ["MKL_NUM_THREADS"] = "1" # Set MKL threads +os.environ["NUMEXPR_NUM_THREADS"] = "1" # Set numexpr threads +os.environ["OPENBLAS_NUM_THREADS"] = "1" # Set OpenBLAS threads + +def free_gpu_memory(): + """Force garbage collection and empty CUDA cache to free up GPU memory""" + gc.collect() + torch.cuda.empty_cache() + print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / (1024**3):.2f} GB used, " + f"{torch.cuda.memory_reserved() / (1024**3):.2f} GB reserved") + +def parse_args(): + parser = argparse.ArgumentParser(description="Test PCA TopK Attention Benchmark") + parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on") + parser.add_argument("--orig-pca-dir", type=str, required=True, + help="Directory containing original PCA components") + parser.add_argument("--num-gen-steps", type=int, default=100, + help="Number of generation steps") + parser.add_argument("--cache-seq-len", type=int, default=2048, + help="Cache sequence length (prompt length)") + parser.add_argument("--top-d", type=int, default=64, + help="Top-d parameter (related to topk)") + parser.add_argument("--topk", type=int, default=None, + help="Direct topk parameter (overrides calculation from top-d if provided)") + parser.add_argument("--topr", type=int, default=32, + help="Direct topr parameter") + parser.add_argument("--output-csv", type=str, default=None, + help="Output CSV file path") + parser.add_argument("--batch-size", type=int, default=1, + help="Batch size") + parser.add_argument("--num-heads", type=int, default=32, + help="Number of attention heads") + parser.add_argument("--num-layers", type=int, default=32, + help="Number of transformer layers") + parser.add_argument("--run-loki-without-sparsity", action="store_true", default=True, + help="Whether to run Loki without sparsity benchmark") + parser.add_argument("--run-attention-query-key-sparsity", action="store_true", default=True, + help="Whether to run Loki with attention-query-key sparsity benchmark") + parser.add_argument("--sparsity-type", type=str, default="attention-query-key", + help="Sparsity type to use: attention-query-key") + + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Set device + device = torch.device(args.device) + torch.cuda.set_device(device) + + # Print GPU information + print(f"CUDA available: {torch.cuda.is_available()}") + print(f"GPU count: {torch.cuda.device_count()}") + print(f"GPU name: {torch.cuda.get_device_name(0)}") + print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB") + + # Create compute_files directory for JSON outputs + os.makedirs("compute_files", exist_ok=True) + + # Ensure output directory exists if output_csv is provided + if args.output_csv: + os.makedirs(os.path.dirname(args.output_csv), exist_ok=True) + + # Free memory before benchmark + free_gpu_memory() + + # Calculate topk and topr based on top-d or use direct values if provided + if args.topk is None: + topk = args.cache_seq_len // args.top_d + else: + topk = args.topk + + topr = args.topr # Use the provided topr value + + print(f"Running benchmark with parameters:") + print(f" Prompt length: {args.cache_seq_len}") + print(f" Gen steps: {args.num_gen_steps}") + print(f" Batch size: {args.batch_size}") + if args.topk is None: + 
print(f" Top-d: {args.top_d} (resulting in topk={topk})") + else: + print(f" Topk: {topk} (directly specified)") + print(f" Topr: {topr}") + + with torch.no_grad(): + try: + # Create file paths for results + sparse_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_topk_{topk}_topr_{topr}_sparse.json" + loki_without_sparsity_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_loki_without_sparsity_topk_{topk}_topr_{topr}.json" + attention_query_key_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_attention_query_key_topk_{topk}_topr_{topr}.json" + + # Variable to track if we have sparse results + times_sparse = None + + # First run: Loki + Sparsity (nvidia-apex-2-4) + # print(f"\nRunning Loki + Sparsity (nvidia-apex-2-4) benchmark...") + # try: + # times_sparse, _ = benchmark_attention( + # prompt_length=args.cache_seq_len, + # num_gen_steps=args.num_gen_steps, + # batch_size=args.batch_size, + # num_heads=args.num_heads, + # num_layers=args.num_layers, + # topk=topk, + # topr=topr, + # vanilla=False, + # pcatopk=True, + # sparsity_type="nvidia-apex-2-4", + # dtype=torch.float16 + # ) + # + # # Save sparse results + # print(f"Saving to {sparse_filename}") + # with open(sparse_filename, "w") as f: + # json.dump(times_sparse, f, indent=2) + # + # # Print sparse results + # print("\nLoki + Sparsity Times:") + # for key, value in times_sparse.items(): + # print(f" {key}: {value:.6f} s") + # print(f"Net time (minus cache updates): {times_sparse.get('total', 0) - times_sparse.get('cache-update', 0):.6f} s") + # except Exception as e: + # print(f"Error in Loki + Sparsity benchmark: {str(e)}") + # traceback.print_exc() + + # Set times_sparse to None since we're not running this benchmark + times_sparse = None + + # Free memory before vanilla benchmark + free_gpu_memory() + + # Variable to track if we have vanilla results + times_loki_without_sparsity = None + + # Second run: Loki without Sparsity + if args.run_loki_without_sparsity: + print(f"\nRunning Loki without Sparsity benchmark...") + try: + times_loki_without_sparsity, _ = benchmark_attention( # Note: we're getting the first return value now + prompt_length=args.cache_seq_len, + num_gen_steps=args.num_gen_steps, + batch_size=args.batch_size, + num_heads=args.num_heads, + num_layers=args.num_layers, + topk=topk, + topr=topr, + vanilla=False, # Use PCA TopK, not vanilla attention + pcatopk=True, # Use PCA TopK + # No sparsity_type parameter, so no 2:4 sparsity + dtype=torch.float16 + ) + + # Add missing attributes to vanilla results for consistent comparison + # These operations don't exist in vanilla but are in sparse + if times_loki_without_sparsity: + missing_attrs = ['project', 'top-k', 'reshape-0', 'reshape-1', 'qk-matmul-2'] + for attr in missing_attrs: + if attr not in times_loki_without_sparsity: + times_loki_without_sparsity[attr] = 0.0 + else: + print("Error: No timing results returned for Loki without Sparsity benchmark") + + # Save vanilla results + print(f"Saving to {loki_without_sparsity_filename}") + with open(loki_without_sparsity_filename, "w") as f: + json.dump(times_loki_without_sparsity, f, indent=2) + + # Print vanilla results + print("\nLoki without Sparsity Times:") + for key, value in times_loki_without_sparsity.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_loki_without_sparsity.get('total', 0) - times_loki_without_sparsity.get('cache-update', 0):.6f} s") + except Exception as e: + 
print(f"Error in Loki without Sparsity benchmark: {str(e)}") + traceback.print_exc() + + # Third run: Loki + Attention Query/Key Sparsity + if args.run_attention_query_key_sparsity: + print(f"\nRunning Loki + Attention Query/Key Sparsity benchmark...") + try: + times_attention_query_key, _ = benchmark_attention( + prompt_length=args.cache_seq_len, + num_gen_steps=args.num_gen_steps, + batch_size=args.batch_size, + num_heads=args.num_heads, + num_layers=args.num_layers, + topk=topk, + topr=topr, + vanilla=False, + pcatopk=True, + sparsity_type="attention-query-key", + dtype=torch.float16 + ) + + # Add missing attributes if needed for consistent comparison + if times_attention_query_key: + # Check if query-key-sparsity is in the times + if 'query-key-sparsity' not in times_attention_query_key: + times_attention_query_key['query-key-sparsity'] = 0.0 + else: + print("Error: No timing results returned for Loki + Attention Query/Key Sparsity benchmark") + + # Save attention query/key results + print(f"Saving to {attention_query_key_filename}") + with open(attention_query_key_filename, "w") as f: + json.dump(times_attention_query_key, f, indent=2) + + # Print attention query/key results + print("\nLoki + Attention Query/Key Sparsity Times:") + for key, value in times_attention_query_key.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_attention_query_key.get('total', 0) - times_attention_query_key.get('cache-update', 0):.6f} s") + except Exception as e: + print(f"Error in Loki + Attention Query/Key Sparsity benchmark: {str(e)}") + traceback.print_exc() + + # Save to CSV if requested + if args.output_csv: + import csv + with open(args.output_csv, 'w', newline='') as csvfile: + fieldnames = ['method', 'prompt_length', 'gen_steps', 'top_d', 'total_time', 'cache_update_time', 'net_time'] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + # Track if we have any results to write + wrote_results = False + + # Try to write Loki + Sparsity row if available + # if times_sparse: + # writer.writerow({ + # 'method': 'loki_sparsity', + # 'prompt_length': args.cache_seq_len, + # 'gen_steps': args.num_gen_steps, + # 'top_d': args.top_d, + # 'total_time': times_sparse.get('total', 0), + # 'cache_update_time': times_sparse.get('cache-update', 0), + # 'net_time': times_sparse.get('total', 0) - times_sparse.get('cache-update', 0) + # }) + # wrote_results = True + + # Vanilla Loki row if available + if args.run_loki_without_sparsity and times_loki_without_sparsity: + writer.writerow({ + 'method': 'loki_without_sparsity', + 'prompt_length': args.cache_seq_len, + 'gen_steps': args.num_gen_steps, + 'top_d': args.top_d, + 'total_time': times_loki_without_sparsity.get('total', 0), + 'cache_update_time': times_loki_without_sparsity.get('cache-update', 0), + 'net_time': times_loki_without_sparsity.get('total', 0) - times_loki_without_sparsity.get('cache-update', 0) + }) + wrote_results = True + + # Attention Query/Key Sparsity row if available + if args.run_attention_query_key_sparsity and times_attention_query_key: + writer.writerow({ + 'method': 'loki_attention_query_key_sparsity', + 'prompt_length': args.cache_seq_len, + 'gen_steps': args.num_gen_steps, + 'top_d': args.top_d, + 'total_time': times_attention_query_key.get('total', 0), + 'cache_update_time': times_attention_query_key.get('cache-update', 0), + 'net_time': times_attention_query_key.get('total', 0) - times_attention_query_key.get('cache-update', 0) + }) + wrote_results = True + + if 
wrote_results: + print(f"Results saved to CSV: {args.output_csv}") + else: + print(f"Warning: No results were written to CSV file.") + + except Exception as e: + print(f"Error in benchmark: {str(e)}") + traceback.print_exc() + + # Final memory cleanup + free_gpu_memory() + + print("All benchmarks completed!") From 40636e983308dbe887224848d4ee61eae4840bad Mon Sep 17 00:00:00 2001 From: Sukriti Paul Date: Mon, 12 May 2025 15:01:58 -0400 Subject: [PATCH 3/4] Updated eval_compute.py --- README.md | 7 ++ evaluate_compute.py | 125 ++++++++++++++++--- methods/pca_topk/attention_benchmark_apex.py | 13 +- test_attention_benchmark_fixed.py | 61 +++++++++ 4 files changed, 186 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index a130f0e..fd2b739 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,15 @@ When running with `sparsity_type = "attention-query-key"` (enabled via `--run-at ## Usage Example +Benchmark for vanilla vs loki vs sparse naive +```bash +python evaluate_compute.py + +``` + Run the benchmark with sparse attention query-key optimization: + ```bash python test_attention_benchmark_fixed.py --orig-pca-dir /cmlscratch/sukriti5/pca_stuff/pca_components --cache-seq-len 3500 --num-gen-steps 10 --top-d 32 --num-heads 16 --run-loki-without-sparsity --run-attention-query-key-sparsity --output-csv ./attention_query_key_results.csv ``` diff --git a/evaluate_compute.py b/evaluate_compute.py index bd9cedd..7fb113a 100644 --- a/evaluate_compute.py +++ b/evaluate_compute.py @@ -1,26 +1,119 @@ -from methods.pca_topk.attention_benchmark import benchmark_attention +from methods.pca_topk.attention_benchmark_apex import benchmark_attention import json import torch +import os +import gc +import sys +import traceback - -# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False +os.environ["OMP_NUM_THREADS"] = "1" # Set OpenMP threads +os.environ["MKL_NUM_THREADS"] = "1" # Set MKL threads +os.environ["NUMEXPR_NUM_THREADS"] = "1" # Set numexpr threads +os.environ["OPENBLAS_NUM_THREADS"] = "1" # Set OpenBLAS threads + +def free_gpu_memory(): + gc.collect() + torch.cuda.empty_cache() + print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / (1024**3):.2f} GB used, " + f"{torch.cuda.memory_reserved() / (1024**3):.2f} GB reserved") + if __name__ == "__main__": + os.makedirs("compute_files", exist_ok=True) + print(f"CUDA available: {torch.cuda.is_available()}") + print(f"GPU count: {torch.cuda.device_count()}") + print(f"GPU name: {torch.cuda.get_device_name(0)}") + print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB") + with torch.no_grad(): - prompt_length = 3500 - for num_gen_steps in [512]: - for topk in [2, 4, 8]: - for topr in [2, 4, 8]: - print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size={16}, topk={topk} and topr={topr}") - times_pca_topk, _ = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // topk, topr=128 // topr, vanilla=False) - #with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_pca_topk_opt_first_matmul.json", "w") as f: - with open(f"compute_files/prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json", "w") as f: - json.dump(times_pca_topk, f, indent=2) + for prompt_length in [512, 1024, 2048]: + for num_gen_steps in [64, 128, 256]: + # 1. 
Vanilla Attention (without Loki) + print(f"\nRunning Vanilla Attention benchmark...") + print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16") + try: + free_gpu_memory() + _, times_vanilla = benchmark_attention( + prompt_length=prompt_length, + num_gen_steps=num_gen_steps, + batch_size=16, + vanilla=True, + pcatopk=False, + dtype=torch.float16 + ) + vanilla_filename = f"compute_files/vanilla_prompt_{prompt_length}_gen_{num_gen_steps}.json" + print(f"Saving to {vanilla_filename}") + with open(vanilla_filename, "w") as f: + json.dump(times_vanilla, f, indent=2) + print("Vanilla Attention Times:") + for key, value in times_vanilla.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_vanilla.get('total', 0) - times_vanilla.get('cache-update', 0):.6f} s") + except Exception as e: + print(f"Error in Vanilla benchmark: {str(e)}") + traceback.print_exc() + + for topk in [4, 8]: + for topr in [4]: + # 2. Loki without sparsity + print(f"\nRunning Loki without Sparsity benchmark...") + print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16, topk={topk} and topr={topr}") + try: + free_gpu_memory() + times_loki, _ = benchmark_attention( + prompt_length=prompt_length, + num_gen_steps=num_gen_steps, + batch_size=16, + topk=prompt_length // topk, + topr=128 // topr, + vanilla=False, + pcatopk=True, + # No sparsity_type parameter + dtype=torch.float16 + ) + loki_filename = f"compute_files/loki_prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json" + print(f"Saving to {loki_filename}") + with open(loki_filename, "w") as f: + json.dump(times_loki, f, indent=2) + print("Loki without Sparsity Times:") + for key, value in times_loki.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_loki.get('total', 0) - times_loki.get('cache-update', 0):.6f} s") + except Exception as e: + print(f"Error in Loki benchmark: {str(e)}") + traceback.print_exc() + + # 3. 
Loki with attention-query-key sparsity + print(f"\nRunning Loki with Attention-Query-Key Sparsity benchmark...") + print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size=16, topk={topk} and topr={topr}") + try: + free_gpu_memory() + times_sparsity, _ = benchmark_attention( + prompt_length=prompt_length, + num_gen_steps=num_gen_steps, + batch_size=16, + topk=prompt_length // topk, + topr=128 // topr, + vanilla=False, + pcatopk=True, + sparsity_type="attention-query-key", + dtype=torch.float16 + ) + sparsity_filename = f"compute_files/attention_query_key_prompt_{prompt_length}_gen_{num_gen_steps}_topk_{topk}_topr_{topr}.json" + print(f"Saving to {sparsity_filename}") + with open(sparsity_filename, "w") as f: + json.dump(times_sparsity, f, indent=2) + print("Loki with Attention-Query-Key Sparsity Times:") + for key, value in times_sparsity.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_sparsity.get('total', 0) - times_sparsity.get('cache-update', 0):.6f} s") + except Exception as e: + print(f"Error in Attention-Query-Key Sparsity benchmark: {str(e)}") + traceback.print_exc() + + print("\nAll benchmarks completed!") + free_gpu_memory() - _, times_vanilla = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // topk, topr=128 // topr, pcatopk=False) - with open(f"compute_files/prompt_{prompt_length}_gen_{num_gen_steps}_vanilla.json", "w") as f: - json.dump(times_vanilla, f, indent=2) - diff --git a/methods/pca_topk/attention_benchmark_apex.py b/methods/pca_topk/attention_benchmark_apex.py index 1645f0d..a1692ea 100644 --- a/methods/pca_topk/attention_benchmark_apex.py +++ b/methods/pca_topk/attention_benchmark_apex.py @@ -348,12 +348,10 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_layers, timer timers.stop('total') else: - # Fallback implementation with basic operations (non-optimized) - # ...omitted for brevity... 
pass -def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000): +def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000, sparsity_type=None): import time torch.set_float32_matmul_precision("highest") @@ -362,6 +360,12 @@ def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen num_heads = prompt_keys[0].shape[1] hidden_size = num_heads * head_dim dtype = prompt_keys[0].dtype + + # Initialize sparsity handler if needed + sparsity_handler = None + if sparsity_type: + from methods.pca_topk.sparsity_utils import SparsityHandler + sparsity_handler = SparsityHandler(sparsity_type) # Initialize dense attention projections for each layer (more realistic) attention_projections = [ @@ -507,7 +511,8 @@ def benchmark_attention(batch_size=1, cache3.update(prompt_keys[i], prompt_values[i], prompt_keys[i], i) timers = Timers() micro_bench_actual_attention(cache3, prompt_keys, num_layers=num_layers, - num_gen_steps=num_gen_steps, timers=timers) + num_gen_steps=num_gen_steps, timers=timers, + sparsity_type=sparsity_type) del cache3 times = timers.get_times() print("Average time (minus cache updates) is - ") diff --git a/test_attention_benchmark_fixed.py b/test_attention_benchmark_fixed.py index 7a1549a..5285a24 100644 --- a/test_attention_benchmark_fixed.py +++ b/test_attention_benchmark_fixed.py @@ -56,6 +56,8 @@ def parse_args(): help="Whether to run Loki without sparsity benchmark") parser.add_argument("--run-attention-query-key-sparsity", action="store_true", default=True, help="Whether to run Loki with attention-query-key sparsity benchmark") + parser.add_argument("--run-vanilla", action="store_true", default=True, + help="Whether to run vanilla attention benchmark (without Loki)") parser.add_argument("--sparsity-type", type=str, default="attention-query-key", help="Sparsity type to use: attention-query-key") @@ -108,6 +110,7 @@ def parse_args(): sparse_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_topk_{topk}_topr_{topr}_sparse.json" loki_without_sparsity_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_loki_without_sparsity_topk_{topk}_topr_{topr}.json" attention_query_key_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_attention_query_key_topk_{topk}_topr_{topr}.json" + vanilla_filename = f"compute_files/prompt_{args.cache_seq_len}_gen_{args.num_gen_steps}_vanilla.json" # Variable to track if we have sparse results times_sparse = None @@ -234,6 +237,51 @@ def parse_args(): print(f"Error in Loki + Attention Query/Key Sparsity benchmark: {str(e)}") traceback.print_exc() + # Fourth run: Vanilla Attention (without Loki) + times_vanilla = None + if args.run_vanilla: + print(f"\nRunning Vanilla Attention benchmark (without Loki)...") + try: + # Free memory before vanilla benchmark + free_gpu_memory() + + _, times_vanilla = benchmark_attention( + prompt_length=args.cache_seq_len, + num_gen_steps=args.num_gen_steps, + batch_size=args.batch_size, + num_heads=args.num_heads, + num_layers=args.num_layers, + topk=topk, + topr=topr, + vanilla=True, # Use vanilla attention + pcatopk=False, # Don't use PCA TopK + dtype=torch.float16 + ) + + # Add missing attributes for consistent comparison with other methods + if times_vanilla: + # These operations don't exist in vanilla but are in other methods + missing_attrs = ['project', 'top-k', 'reshape-0', 'reshape-1', 'qk-matmul-2', 'query-key-sparsity'] + for attr in 
missing_attrs: + if attr not in times_vanilla: + times_vanilla[attr] = 0.0 + else: + print("Error: No timing results returned for Vanilla Attention benchmark") + + # Save vanilla results + print(f"Saving to {vanilla_filename}") + with open(vanilla_filename, "w") as f: + json.dump(times_vanilla, f, indent=2) + + # Print vanilla results + print("\nVanilla Attention Times:") + for key, value in times_vanilla.items(): + print(f" {key}: {value:.6f} s") + print(f"Net time (minus cache updates): {times_vanilla.get('total', 0) - times_vanilla.get('cache-update', 0):.6f} s") + except Exception as e: + print(f"Error in Vanilla Attention benchmark: {str(e)}") + traceback.print_exc() + # Save to CSV if requested if args.output_csv: import csv @@ -283,6 +331,19 @@ def parse_args(): 'net_time': times_attention_query_key.get('total', 0) - times_attention_query_key.get('cache-update', 0) }) wrote_results = True + + # Vanilla Attention row if available + if args.run_vanilla and times_vanilla: + writer.writerow({ + 'method': 'vanilla_attention', + 'prompt_length': args.cache_seq_len, + 'gen_steps': args.num_gen_steps, + 'top_d': args.top_d, + 'total_time': times_vanilla.get('total', 0), + 'cache_update_time': times_vanilla.get('cache-update', 0), + 'net_time': times_vanilla.get('total', 0) - times_vanilla.get('cache-update', 0) + }) + wrote_results = True if wrote_results: print(f"Results saved to CSV: {args.output_csv}") From 1d48a0d884c78cac4a412515db32d84165feec49 Mon Sep 17 00:00:00 2001 From: Sukriti Paul Date: Mon, 12 May 2025 15:10:23 -0400 Subject: [PATCH 4/4] Updated eval_compute.py --- methods/pca_topk/attention_benchmark_apex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/methods/pca_topk/attention_benchmark_apex.py b/methods/pca_topk/attention_benchmark_apex.py index a1692ea..84025bd 100644 --- a/methods/pca_topk/attention_benchmark_apex.py +++ b/methods/pca_topk/attention_benchmark_apex.py @@ -482,7 +482,7 @@ def benchmark_attention(batch_size=1, times_pca_topk = None if pcatopk: print("PCA TOPK Optimized") - for _ in range(10): + for _ in range(100): cache2 = PcaTopKCache() for i in range(num_layers): cache2.update(prompt_keys[i].transpose(0,1).contiguous(), @@ -505,7 +505,7 @@ def benchmark_attention(batch_size=1, times_vanilla = None if vanilla: print("Actual Attention") - for _ in range(10): + for _ in range(100): cache3= PcaTopKCache() for i in range(num_layers): cache3.update(prompt_keys[i], prompt_values[i], prompt_keys[i], i)
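
Patch 3 has `micro_bench_actual_attention` construct a `SparsityHandler` from `methods.pca_topk.sparsity_utils` whenever `sparsity_type` is set, but that module is not included in this series and only the constructor call `SparsityHandler(sparsity_type)` is visible. The sketch below is a hypothetical stand-in consistent with that single call site, using a plain PyTorch 2:4 magnitude mask; the real module may look quite different.

```python
# Hypothetical stand-in for methods/pca_topk/sparsity_utils.py -- illustrative only.
# The patches only show SparsityHandler(sparsity_type) being constructed; the
# interface and masking logic below are assumptions.
import torch


class SparsityHandler:
    def __init__(self, sparsity_type: str):
        if sparsity_type != "attention-query-key":
            raise ValueError(f"Unsupported sparsity type: {sparsity_type}")
        self.sparsity_type = sparsity_type

    @staticmethod
    def two_four_mask(weight: torch.Tensor) -> torch.Tensor:
        """Keep the 2 largest-magnitude entries in every group of 4 along the last dim."""
        orig_shape = weight.shape
        groups = weight.abs().reshape(-1, 4)        # (num_groups, 4)
        keep = groups.topk(2, dim=-1).indices       # indices of the 2 largest per group
        mask = torch.zeros_like(groups, dtype=torch.bool)
        mask.scatter_(-1, keep, True)
        return mask.reshape(orig_shape)

    def apply(self, weight: torch.Tensor) -> torch.Tensor:
        """Return a 2:4-pruned copy of `weight` (last dimension divisible by 4)."""
        return weight * self.two_four_mask(weight)
```

Under this assumption, a pruned query/key projection weight could be obtained with `SparsityHandler("attention-query-key").apply(linear.weight)`.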
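
The updated `evaluate_compute.py` writes one JSON file per configuration into `compute_files/`, and both scripts report a net time of `total` minus `cache-update`. Below is a minimal sketch for collecting those files into a single table, assuming each JSON maps timer names to seconds and contains `total` and `cache-update` entries as the scripts above expect.

```python
# summarize_compute_files.py -- a minimal sketch, not part of the patches above.
# Assumes evaluate_compute.py has populated compute_files/ with JSON files that
# map timer names to seconds and include 'total' and 'cache-update' entries.
import glob
import json
import os

rows = []
for path in sorted(glob.glob("compute_files/*.json")):
    with open(path) as f:
        times = json.load(f)
    total = times.get("total", 0.0)
    cache_update = times.get("cache-update", 0.0)
    rows.append((os.path.basename(path), total, cache_update, total - cache_update))

# Print filename, total, cache-update, and net time (total minus cache updates).
print(f"{'file':60s} {'total':>10s} {'cache-upd':>10s} {'net':>10s}")
for name, total, cache_update, net in rows:
    print(f"{name:60s} {total:10.6f} {cache_update:10.6f} {net:10.6f}")
```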