28 changes: 28 additions & 0 deletions examples/llm_eval/lm_eval_hf.py
@@ -43,9 +43,11 @@
from lm_eval.api.model import T
from lm_eval.models.huggingface import HFLM
from quantization_utils import quantize_model
from sparse_attention_utils import sparsify_model

import modelopt.torch.opt as mto
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified


def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
@@ -57,9 +59,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
calib_size = arg_dict.pop("calib_size", 512)
compress = arg_dict.pop("compress", False)

# Sparse attention arguments
sparse_cfg = arg_dict.pop("sparse_cfg", None)

additional_config = {} if additional_config is None else additional_config
additional_config = {k: v for k, v in additional_config.items() if v is not None}

# Force eager attention if sparse attention is requested
if sparse_cfg:
additional_config["attn_implementation"] = "eager"
warnings.warn(
"Sparse attention requires attn_implementation='eager'. "
"Forcing eager attention implementation."
)

# Enable automatic save/load of modelopt state huggingface checkpointing
mto.enable_huggingface_checkpointing()

@@ -85,6 +98,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
compress=compress,
)

if sparse_cfg:
if is_attn_sparsified(model_obj.model):
warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
else:
sparsify_model(
model=model_obj,
sparse_cfg=sparse_cfg,
)

return model_obj


@@ -120,6 +142,11 @@ def setup_parser_with_modelopt_args():
action="store_true",
help="Compress the model after quantization",
)
parser.add_argument(
"--sparse_cfg",
type=str,
help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)",
)
return parser
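A minimal standalone sketch of how the new flag is consumed (plain argparse here rather than the full lm_eval parser; the config name is one of the custom entries added in sparse_attention_utils.py):

import argparse

# Standalone illustration of the flag added above; the real script builds on lm_eval's parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--sparse_cfg",
    type=str,
    help="Sparse attention configuration name or dict",
)
args = parser.parse_args(["--sparse_cfg", "SPARSE_CONSERVATIVE"])
# create_from_arg_obj() pops "sparse_cfg" from the model arguments; when it is set,
# eager attention is forced and sparsify_model() is applied after quantization.
print(args.sparse_cfg)  # -> SPARSE_CONSERVATIVE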


@@ -142,6 +169,7 @@ def setup_parser_with_modelopt_args():
"calib_batch_size": args.calib_batch_size,
"calib_size": args.calib_size,
"compress": args.compress,
"sparse_cfg": args.sparse_cfg,
}
)

25 changes: 25 additions & 0 deletions examples/llm_eval/mmlu.py
@@ -48,6 +48,7 @@
from fire import Fire
from modeling import EvalModel, select_model
from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model
from sparse_attention_utils import sparsify_model
from tqdm import tqdm

try:
@@ -56,6 +57,7 @@
LLM = None # type: ignore[misc]
import modelopt.torch.opt as mto
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified

os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -227,6 +229,7 @@ def main(
batch_size: int = 0,
calib_size: int = 512,
dtype: str = "bfloat16",
sparse_cfg: str | None = None,
**kwargs,
):
random.seed(RAND_SEED)
@@ -263,6 +266,14 @@
max_batch_size=1,
)
else:
# Force eager attention if sparse attention is requested
if sparse_cfg:
kwargs["attn_implementation"] = "eager"
warnings.warn(
"Sparse attention requires attn_implementation='eager'. "
"Forcing eager attention implementation."
)

model = select_model(
max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs
)
@@ -283,6 +294,20 @@ def main(
auto_quantize_bits=auto_quantize_bits,
)

# Apply sparse attention if requested
if sparse_cfg:
model.load()

if is_attn_sparsified(model.model):
warnings.warn(
"Skipping sparse attention: model already has sparse attention applied."
)
else:
sparsify_model(
model=model,
sparse_cfg=sparse_cfg,
)

for subject in tqdm(subjects):
dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[
:ntrain
5 changes: 5 additions & 0 deletions examples/llm_eval/modeling.py
@@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel):
lora_path: str = ""
device: str = "cuda"
load_8bit: bool = False
attn_implementation: str | None = None

def load(self):
if self.model is None:
@@ -188,6 +189,8 @@ def load(self):
if self.load_8bit:
args.update(device_map="auto", load_in_8bit=True)
args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
if self.attn_implementation:
args["attn_implementation"] = self.attn_implementation
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args)
print_gpu_utilization()
if self.lora_path:
@@ -241,6 +244,8 @@ def load(self):
if self.load_8bit:
args.update(device_map="auto", load_in_8bit=True)
args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
if self.attn_implementation:
args["attn_implementation"] = self.attn_implementation
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, trust_remote_code=True, **args
)
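The new attn_implementation field is forwarded verbatim into from_pretrained; a minimal sketch of the effective call on the causal-LM path (checkpoint path and dtype are placeholders):

import torch
from transformers import AutoModelForCausalLM

# Sketch of what load() ends up calling when attn_implementation="eager" is requested.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",             # placeholder for self.model_path
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,   # stands in for getattr(torch, self.dtype)
    attn_implementation="eager",  # required before sparse attention can be applied
)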
111 changes: 111 additions & 0 deletions examples/llm_eval/sparse_attention_utils.py
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for sparse attention integration with llm_eval."""

import modelopt.torch.sparsity.attention_sparsity as mtsa

# Custom sparse attention configurations
CUSTOM_SPARSE_CONFIG = {
"SPARSE_CONSERVATIVE": {
"sparse_cfg": {
"*attn*": {
"method": "flash_skip_softmax",
"threshold": {"prefill": 5e-4, "decode": 1e-5},
"br": 128,
"bc": 128,
"backend": "pytorch",
"enable": True,
},
"default": {"enable": False},
},
},
"SPARSE_AGGRESSIVE": {
"sparse_cfg": {
"*attn*": {
"method": "flash_skip_softmax",
"threshold": {"prefill": 5e-3, "decode": 5e-4},
"br": 128,
"bc": 128,
"backend": "pytorch",
"enable": True,
},
"default": {"enable": False},
},
},
}
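For reference, a name passed via --sparse_cfg is resolved against this table first and then against the predefined configs exported by modelopt.torch.sparsity.attention_sparsity; a minimal sketch of that lookup, mirroring sparsify_model() below (the fallback name is taken from the CLI help above and assumed to exist in modelopt):

import modelopt.torch.sparsity.attention_sparsity as mtsa

name = "SPARSE_CONSERVATIVE"  # or e.g. "SKIP_SOFTMAX_DEFAULT" from the predefined configs
cfg = CUSTOM_SPARSE_CONFIG.get(name) or getattr(mtsa, name, None)
if cfg is None:
    raise ValueError(f"Unknown sparse_cfg: {name}")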


def _extract_model(model_obj):
"""Extract actual model from wrapper (HFLM or EvalModel)."""
if hasattr(model_obj, "gpt2"):
return model_obj.gpt2
elif hasattr(model_obj, "model"):
return model_obj.model
else:
return model_obj


def sparsify_model(
model,
sparse_cfg: str,
backend=None,
):
"""Apply sparse attention to model with optional RULER calibration.

Args:
model: Model wrapper (HFLM or EvalModel) or raw model
sparse_cfg: Sparse attention config name or dict
backend: Backend to use (optional, overrides config backend)

Returns:
The model with sparse attention applied

Note:
Calibration is automatically triggered if the config contains a 'calibration' field.
The calibration will auto-generate RULER dataset from the model's tokenizer.
"""
# Extract actual model
net = _extract_model(model)

# Resolve config
if isinstance(sparse_cfg, str):
# Try custom configs first
mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg)
if mtsa_cfg is None:
# Try predefined configs
mtsa_cfg = getattr(mtsa, sparse_cfg, None)
if mtsa_cfg is None:
raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}")
else:
mtsa_cfg = sparse_cfg

# Override backend if specified
if backend:
if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg:
modified_sparse_cfg = {}
for pattern, cfg in mtsa_cfg["sparse_cfg"].items():
modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg
if isinstance(modified_cfg, dict):
modified_cfg["backend"] = backend
modified_sparse_cfg[pattern] = modified_cfg
mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}

# Apply sparsification
print(f"\nApplying sparse attention with config: {sparse_cfg}")
mtsa.sparsify(net, mtsa_cfg)
print("Sparse attention applied successfully!")

return model
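A hypothetical end-to-end usage sketch of this helper outside the eval scripts (the model name is an example; eager attention is assumed to be required, as the calling scripts enforce):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sparse_attention_utils import sparsify_model

# Hypothetical checkpoint; any HF causal LM supported by modelopt should behave similarly.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # sparse attention hooks the eager attention path
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Raw models are accepted as well as HFLM/EvalModel wrappers; the helper returns its input object.
model = sparsify_model(model, sparse_cfg="SPARSE_CONSERVATIVE")

The optional backend argument, when given, overwrites the backend field of every pattern in the resolved config before sparsification.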