 
 def quantize_model(model, args, tokenizer):
     """
-    Quantize a PyTorch model using ModelOpt quantization.
+    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).
 
-    This function performs post-training quantization (PTQ) on the model using
-    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
-    quantization formats.
+    This function applies quantization to reduce model precision for faster inference
+    while maintaining acceptable accuracy. It uses calibration data generated from
+    the provided tokenizer to determine optimal quantization parameters.
 
+    Supported quantization formats:
+        - fp8: 8-bit floating-point quantization
+        - nvfp4: 4-bit NVIDIA floating-point quantization
     Args:
-        model: PyTorch model to quantize
-        args: Arguments containing quantization format and debug settings
-        tokenizer: Tokenizer for creating calibration dataloader
+        model: PyTorch model to quantize. Must be in evaluation mode.
+        args: Command-line arguments containing quant_format and debug settings
+        tokenizer: Hugging Face tokenizer for creating calibration data
 
     Returns:
-        Quantized model with reduced precision weights and activations
-
-    Raises:
-        RuntimeError: If unsupported quantization format is specified
+        Quantized model
     """
     # Create calibration dataloader for quantization
     calib_dataloader = get_dataset_dataloader(
@@ -51,9 +51,9 @@ def quantize_model(model, args, tokenizer):
         num_samples=512,
         device="cuda:0",
     )
-    if args.qformat == "fp8":
+    if args.quant_format == "fp8":
         quant_cfg = mtq.FP8_DEFAULT_CFG
-    elif args.qformat == "nvfp4":
+    elif args.quant_format == "nvfp4":
         quant_cfg = mtq.NVFP4_DEFAULT_CFG
     else:
         raise RuntimeError("Unsupported quantization format")
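
Reviewer note: the hunk above only selects a quantization config; it does not show where `calib_dataloader` and `quant_cfg` are consumed. As an aid, here is a minimal sketch of the typical ModelOpt PTQ call that would follow, assuming the `mtq` alias already used in this file; the `forward_loop` helper and the `"input_ids"` batch key are assumptions, not code from this PR:

```python
import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Feed calibration batches through the model so ModelOpt can observe
    # activation ranges before quantizing. The batch key is an assumption.
    for batch in calib_dataloader:
        model(batch["input_ids"])

# Quantize in place using the selected config (FP8 or NVFP4).
model = mtq.quantize(model, quant_cfg, forward_loop)
```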
@@ -108,7 +108,38 @@ def forward(self, input):
         return torch.nn.functional.linear(input, weight, self.bias)
 
 
-def convert_linear_to_tensorrt_quantized(model, model_name):
+def load_quantization_config(model_name):
+    """
+    Load quantization configuration from a Hugging Face model.
+    Args:
+        model_name (str): Local directory path or model identifier
+    Returns:
+        dict or None: Quantization configuration, or None if no config is found.
+    """
+    # Determine if model_name is a local directory or needs to be downloaded
+    if os.path.isdir(model_name):
+        model_path = model_name
+    else:
+        # Download model from Hugging Face Hub
+        model_path = snapshot_download(
+            model_name,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["original/**/*"],
+            revision=None,
+        )
+    hf_quant_config = None
+    # Load and parse quantization configuration
+    hf_quant_config_path = f"{model_path}/hf_quant_config.json"
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path, "r") as f:
+            hf_quant_config = json.load(f)
+        hf_quant_config = hf_quant_config["quantization"]
+        hf_quant_config["model_path"] = model_path
+
+    return hf_quant_config
+
+
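
Reviewer note: the PR does not show the on-disk shape of `hf_quant_config.json`. A minimal file consistent with the keys the new helper reads would be `{"quantization": {"quant_algo": "FP8"}}`; the helper returns the inner `"quantization"` dict with `"model_path"` added. A hypothetical usage sketch (the path is illustrative):

```python
# Hypothetical usage of the helper added above; the path is illustrative.
cfg = load_quantization_config("path/to/prequantized-checkpoint")
if cfg is not None:
    print(cfg["quant_algo"])  # e.g. "FP8" or "NVFP4"
    print(cfg["model_path"])  # resolved local checkpoint directory
```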
+def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
     """
     Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
 
@@ -119,58 +150,37 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
 
     The function:
     1. Loads quantization scales from Hugging Face model files (SafeTensors)
-    2. Parses quantization configuration from hf_quant_config.json
-    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
-    4. Applies appropriate quantization based on the model's quantization format
+    2. Replaces standard linear layers with TensorRTQuantizedLinear layers
+    3. Applies appropriate quantization based on the model's quantization format
 
     Note: This function only quantizes linear operations and is intended for use
     with pre-quantized Hugging Face models that have been quantized using ModelOpt.
 
     Args:
         model: PyTorch model to quantize
-        model_name: Path to Hugging Face model directory or model identifier
+        hf_quant_config: Quantization configuration returned by load_quantization_config()
 
     Returns:
         Model with quantized linear layers
 
     Raises:
         RuntimeError: If quantization config is not found or unsupported format
     """
-    # Determine if model_name is a local directory or needs to be downloaded
-    if os.path.isdir(model_name):
-        hf_folder = model_name
-    else:
-        # Download model from Hugging Face Hub
-        hf_folder = snapshot_download(
-            model_name,
-            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-            ignore_patterns=["original/**/*"],
-            revision=None,
-        )
-
+    model_path = hf_quant_config["model_path"]
     # Load all tensors from SafeTensors files
     tensors = {}
-    for file in os.listdir(hf_folder):
+    for file in os.listdir(model_path):
         if file.endswith(".safetensors"):
             with safe_open(
-                os.path.join(hf_folder, file), framework="pt", device="cpu"
+                os.path.join(model_path, file), framework="pt", device="cpu"
             ) as f:
                 tensor_names = f.keys()
                 for name in tensor_names:
                     tensors[name] = f.get_tensor(name)
 
-    # Load and parse quantization configuration
-    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
-    if os.path.exists(hf_quant_config_path):
-        with open(hf_quant_config_path, "r") as f:
-            hf_quant_config = json.load(f)
-        hf_quant_config = hf_quant_config["quantization"]
-
-        hf_quant_algo = hf_quant_config.pop("quant_algo", None)
-        if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
-            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
-    else:
-        raise RuntimeError("No quantization config found")
+    hf_quant_algo = hf_quant_config.get("quant_algo", None)
+    if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
+        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
 
     # Iterate through all modules in the model
     for name, module in model.named_modules():
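
The hunk is truncated before the body of this loop. Based on the docstring ("Replaces standard linear layers with TensorRTQuantizedLinear layers"), the remainder follows the usual module-swap pattern; the sketch below is a reviewer illustration only, and the scale-tensor naming and the `TensorRTQuantizedLinear` constructor arguments are assumptions, not the PR's code:

```python
import torch

# Sketch of the truncated loop; tensor naming and constructor are assumptions.
for parent_name, parent in model.named_modules():
    for child_name, child in parent.named_children():
        if isinstance(child, torch.nn.Linear):
            full_name = f"{parent_name}.{child_name}" if parent_name else child_name
            # Assumed export convention: per-layer scales stored next to weights,
            # e.g. "<layer>.weight_scale" in the safetensors loaded above.
            weight_scale = tensors.get(f"{full_name}.weight_scale")
            setattr(parent, child_name, TensorRTQuantizedLinear(child, weight_scale))
```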