Commit bccb8da

Merge branch 'master' of github.com:IST-DASLab/FP-Quant
2 parents: 8ed4fa6 + c1dc805

10 files changed: 80 additions & 12 deletions

README.md

Lines changed: 63 additions & 6 deletions
````diff
@@ -1,23 +1,29 @@
-# FP-quantization-harness
+# FP Format Quantization Harness
 
-Repository for the the development of a recipe for efficient and accurate weight + activation quantization for low-bit FP formats (FP4, NVFP4, MXFP**B**).
+This is a harness for efficient and accurate weight-and-activation quantization for low-bit FP/INT formats, with and without microscaling, including FP4, NVFP4, and MXFP. These formats are compatible with the NVIDIA Blackwell GPU architecture.
+
+The goal of the repository is to allow you to produce quantized models in these formats.
+Currently, the repository supports the standard microscaled MXFP4 format, together with standard methods such as RTN and GPTQ quantization for the weights. The main new approach supported (which we found to be particularly effective) is a variant of GPTQ, called GPTQ+Had, in which a block-wise Hadamard transform is applied to the weights and activations before quantization. Key to efficiency is that the Hadamard block size matches the microscaling format group size (16 or 32); in turn, this small Hadamard transform is automatically "fused" into our matmul kernels.
+
+The inference code to run models in the `MXFP` format (with speedups) can be found in the [QuTLASS](https://github.com/IST-DASLab/qutlass) repository.
 
 ### Repository structure
 ---
 
 The repository is structured as follows:
 
-* `model_quant.py` - the main script for quantization of the Llama models
+* `model_quant.py` - the main script for quantization of Llama/Qwen models
 * `src/` - source code with implementation of all necessary functionality \
   ```├── quantization``` - quantization functionality \
   ```├── transforms``` - transform functionality \
   ```├── utils``` - utility functions
 
 
+
 ### Usage
 ---
 
-Below is an example of the qat script usage:
+Below is an example of the model quantization script usage:
 
 ```shell
 MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
````
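The GPTQ+Had recipe described in the README changes above hinges on a small Hadamard transform whose block size equals the microscaling group size (16 or 32). As a minimal illustrative sketch (not the repository's fused-kernel implementation; `hadamard_matrix` and `blockwise_hadamard` are hypothetical helper names), the transform could look like:

```python
# Sketch of a block-wise Hadamard transform whose block size matches the
# microscaling group size (hypothetical helpers, not the repo's kernels).
import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    """Sylvester construction of an orthonormal Hadamard matrix; n must be a power of two."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        # H_{2k} = [[H, H], [H, -H]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5  # scale so that H @ H.T = I

def blockwise_hadamard(x: torch.Tensor, group_size: int = 16) -> torch.Tensor:
    """Rotate each contiguous group of `group_size` entries along the last dim."""
    H = hadamard_matrix(group_size).to(x.dtype)
    *lead, d = x.shape
    assert d % group_size == 0
    return (x.reshape(*lead, d // group_size, group_size) @ H.T).reshape(*lead, d)

w = torch.randn(4, 32)
w_rot = blockwise_hadamard(w, group_size=16)
# The normalized Sylvester matrix is symmetric and orthogonal, so applying
# the transform twice recovers the original tensor exactly (up to fp error):
assert torch.allclose(blockwise_hadamard(w_rot, 16), w, atol=1e-5)
```

Because the rotation is orthogonal and acts only within a group, it spreads outliers across the group before quantization and can be undone (or folded into a matmul) at inference time.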
```diff
@@ -101,7 +107,7 @@ Above:
 * `--w_observer` - The observer to use for the weights (`mse` or `minmax`).
 * `--a_group_size` - The number of activations to quantize together.
 * `--parametrization` - Transform parameterization.
-* `--gptq` - Whether to use GPTQ quantization.
+* `--gptq` - Whether to use GPTQ quantization for the weights.
 * `--transform_class` - Transform class (`identity` or `hadamard`).
 * `--dataset_name_or_path` - Dataset to use.
 * `--sequence_length` - Sequence length.
```
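The `--w_observer` choice between `mse` and `minmax` in the hunk above can be illustrated with a toy symmetric integer quantizer (a hypothetical sketch; the repository's FP-format observers differ in detail). `minmax` derives the scale from the largest magnitude, while `mse` additionally searches a grid of shrunken (clipped) scales and keeps the one with the lowest squared reconstruction error:

```python
# Toy comparison of `minmax` vs `mse` scale observers (illustrative only;
# the repo quantizes to FP formats, here we use a symmetric integer grid).
import torch

def quantize(x, scale, levels=7):
    # round-to-nearest onto a symmetric grid [-levels, levels] * scale
    return (x / scale).round().clamp(-levels, levels) * scale

def minmax_scale(x, levels=7):
    return x.abs().max() / levels

def mse_scale(x, levels=7, steps=50):
    # Grid search over progressively clipped scales; i = 0 is the minmax scale,
    # so the result is never worse than minmax on this grid.
    best_scale, best_err = minmax_scale(x, levels), float("inf")
    for i in range(steps):
        scale = minmax_scale(x, levels) * (1 - 0.8 * i / steps)
        err = (x - quantize(x, scale, levels)).pow(2).sum().item()
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale

x = torch.randn(256)
e_minmax = (x - quantize(x, minmax_scale(x))).pow(2).sum()
e_mse = (x - quantize(x, mse_scale(x))).pow(2).sum()
assert e_mse <= e_minmax
```

The intuition: clipping a few outliers shrinks the scale, which reduces rounding error on the bulk of the values; the MSE search trades those two effects off explicitly.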
```diff
@@ -110,4 +116,55 @@ Above:
 * `--save_path` - Path to save the quantized model.
 * `--real_quant` - Whether to save the model in real quantization format.
 * `--eval_perplexity` - Whether to compute perplexity.
-* `--eval_openllm` - Whether to compute OpenLLMv1 scores.
+* `--eval_openllm` - Whether to compute OpenLLMv1 scores.
+
+The `real_quant` option produces models that are runnable on Blackwell architectures (`sm_120`) via transformers and vLLM (currently using this transformers [fork](https://github.com/huggingface/transformers/pull/38696/)).
+
+### Accuracy Evaluations
+
+The results below report the accuracy of quantized Llama-3 and Qwen-3 models
+on the OpenLLM v1 leaderboard. Specifically, we provide average metrics for the following tasks:
+* `mmlu_cot_llama` (exact_match, strict_match)
+* `arc_challenge_llama` (exact_match, strict_match)
+* `gsm8k_llama` (exact_match, strict_match)
+* `hellaswag` (acc_norm)
+* `winogrande` (acc)
+* `truthfulqa_mc2` (acc)
+
+The results for Qwen3 exclude `arc_challenge_llama`, as it turns out to be very noisy.
+
+Below, the left column corresponds to **weight-only** quantization and the right column to **weight-and-activation** quantization. Results for AWQ were produced via the dedicated [AutoAWQ fork](https://github.com/Godofnothing/AutoAWQ-FP).
+
+**Llama-3.1-8B-Instruct**
+
+<p float="left">
+  <img src="assets/llama-3.1-8b-acc-weight_only.png" width="400" />
+  <img src="assets/llama-3.1-8b-acc-weight_and_activation.png" width="400" />
+</p>
+
+**Qwen-3-8B**
+
+<p float="left">
+  <img src="assets/qwen3-3-8b-acc-weight_only.png" width="400" />
+  <img src="assets/qwen3-3-8b-acc-weight_and_activation.png" width="400" />
+</p>
+
+*Notes*: For the NVFP format without the `hadamard` rotation, GPTQ's average performance is below 0.65.
+By and large, `GPTQ+Had` appears to be the best method for preserving accuracy.
+
+### Inference speedups
+
+Below we provide performance numbers for end-to-end inference with QuTLASS kernels vs. the `bf16` baseline for Qwen3 models, on an RTX 5090 GPU.
+Please see the [QuTLASS](https://github.com/IST-DASLab/qutlass) repository for details on how to reproduce these results.
+
+<p float="left">
+  <img src="assets/inference_speedup_qwen3_8b.png" width="400" />
+  <img src="assets/inference_speedup_qwen3_14b.png" width="400" />
+</p>
+
+### Contributors
+
+This project is still in active development. So far, it has benefited from contributions by Denis Kuznedelev, Andrei Panferov, Vage Egiazarian, and Saleh Ashkboos, as well as Dan Alistarh, Michael Goin, and Eldar Kurtic. The [QuTLASS](https://github.com/IST-DASLab/qutlass) repository is developed primarily by Roberto Lopez Castro, with help from Jiale Chen.
```

model_quant.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -316,7 +316,7 @@ def main():
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
         torch_dtype=args.dtype,
-        device_map=device,  # to avoid errors when model is split on mulitple GPUs
+        device_map=None if args.cpu_offload_modules else device,
         low_cpu_mem_usage=True,
     )
     model.config.use_cache = False
@@ -338,7 +338,7 @@ def main():
             args.num_sequences,
             args.seed
         )
-        quantized_state_dict = gptq_quantization(model, calibration_data, args, device)
+        quantized_state_dict = gptq_quantization(model, calibration_data, args, device=device)
     else:
        quantized_state_dict = rtn_quantization(model, args, device)
```

src/quantization/gptq.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -229,6 +229,7 @@ def gptq_quantization(
 ) -> Optional[dict[str, torch.Tensor]]:
     print("GPTQ quantization...")
     orig_dtype = model.config.torch_dtype if args.dtype == "auto" else args.dtype
+    activation_offload_device = "cpu" if args.cpu_offload_activations else None
     # State dict with quantized weights, scales and hadamards
     quantized_state_dict = {}
     # Define common transform kwargs
@@ -261,7 +262,11 @@ def gptq_quantization(
 
     blocks = model.model.layers
     blocks[0] = blocks[0].to(device)
-    blocks[0] = InputCollector(blocks[0], cpu_offload=False)
+    blocks[0] = InputCollector(blocks[0], cpu_offload=activation_offload_device)
+
+    if args.cpu_offload_modules:
+        model.get_input_embeddings().to(device)
+        blocks[0].to(device)
 
     for sample in calibration_data:
         try:
@@ -274,6 +279,9 @@ def gptq_quantization(
     input_kwargs = blocks[0].input_kwargs
     blocks[0] = blocks[0].module
 
+    if args.cpu_offload_modules:
+        model.get_input_embeddings().cpu()
+
     # Iterate over transformer blocks
     for block_idx, block in enumerate(blocks):
         print(f"Processing block {block_idx}...")
@@ -381,12 +389,15 @@ def _hook(_, inp, out):
         out = maybe_first_element(out)
         # change only first input argument
         if len(inp_args) > 0:
-            inp_args[0].data = out
+            inp_args[0].data = out.to(activation_offload_device)
         elif "hidden_states" in inp_kwargs:
-            inp_kwargs["hidden_states"] = out
+            inp_kwargs["hidden_states"] = out.to(activation_offload_device)
         else:
             raise ValueError("Unsupported block input format.")
 
+        if args.cpu_offload_modules:
+            block = block.cpu()
+
     # 10. Clean-up
     del gptq_handles
     del hooks
```
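The `cpu_offload` argument threaded through `InputCollector` above lets calibration inputs be parked on CPU between blocks (note that `Tensor.to(None)` is a no-op, which is why `activation_offload_device = None` disables offloading). A minimal sketch of such a collector, under the assumption that it wraps block 0 and aborts the forward pass once the input is captured (as the `try:` around each calibration sample suggests):

```python
# Hypothetical sketch of an input collector with optional CPU offload;
# the repo's InputCollector differs in detail.
import torch

class InputCollector(torch.nn.Module):
    def __init__(self, module, cpu_offload=None):
        super().__init__()
        self.module = module
        self.offload_device = cpu_offload  # "cpu" to offload, None to keep in place
        self.input_args = []

    def forward(self, x, **kwargs):
        # .to(None) returns the tensor unchanged, so offload is optional
        self.input_args.append(x.detach().to(self.offload_device))
        # Stop the forward pass: only the input to this block is needed
        raise ValueError("input captured")

block = torch.nn.Linear(4, 4)
collector = InputCollector(block, cpu_offload="cpu")
try:
    collector(torch.randn(1, 4))
except ValueError:
    pass
assert collector.input_args[0].device.type == "cpu"
```

The captured inputs are later replayed block by block, with each block's outputs becoming the next block's (optionally offloaded) inputs, exactly as the hook in the hunk above does.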

src/quantization/qconfig.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,7 +8,7 @@ def prepare_quantization_config(group_size: int, format: str) -> dict[str, Any]:
             "forward_method": "abs_max",
             "hadamard_group_size": group_size,
             "modules_to_not_convert": ["lm_head"],
-            "quant_method": "quartet",
+            "quant_method": "fp_quant",
             "store_master_weights": False
         }
     elif format == "nvfp":
```
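The hunk above renames the serialized `quant_method` from `"quartet"` to `"fp_quant"`. A self-contained sketch of the resulting config for the microscaled branch (the enclosing `if` condition is not shown in the diff, so the `"mxfp"` guard here is an assumption; only the listed keys come from the source):

```python
# Sketch of prepare_quantization_config after this commit; the "mxfp"
# branch condition is assumed, the key/value pairs are from the diff.
from typing import Any

def prepare_quantization_config(group_size: int, format: str) -> dict[str, Any]:
    if format == "mxfp":  # assumed branch condition
        return {
            "forward_method": "abs_max",
            "hadamard_group_size": group_size,
            "modules_to_not_convert": ["lm_head"],
            "quant_method": "fp_quant",  # renamed from "quartet" in this commit
            "store_master_weights": False,
        }
    raise ValueError(f"Unsupported format: {format}")

cfg = prepare_quantization_config(32, "mxfp")
assert cfg["quant_method"] == "fp_quant"
```

Loaders dispatch on `quant_method`, so checkpoints saved with the old `"quartet"` tag would not be recognized by code expecting `"fp_quant"`.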
