2020from sparsebit .quantization .quantizers import Quantizer
2121from sparsebit .quantization .tools import QuantizationErrorProfiler
2222from sparsebit .quantization .converters import simplify , fuse_operations
23+ < << << << HEAD
2324from sparsebit .quantization .quant_tracer import QTracer
25+ == == == =
26+ from sparsebit .quantization .regularizers import build_regularizer
27+ > >> >> >> 8 bd7bbc ... add regularizer
2428
2529
2630__all__ = ["QuantModel" ]
@@ -35,6 +39,7 @@ def __init__(self, model: nn.Module, config):
3539 self ._run_simplifiers ()
3640 self ._convert2quantmodule ()
3741 self ._build_quantizer ()
42+ self ._build_regularizer ()
3843 self ._run_fuse_operations ()
3944
4045 def _convert2quantmodule (self ):
@@ -133,6 +138,7 @@ def _sub_build(src, module_name):
133138 update_config (_config , "A" , _sub_build (self .cfg .A , node .target ))
134139 identity_module .build_quantizer (_config )
135140
141+ < << << << HEAD
136142 def _trace (self , model ):
137143 skipped_modules = self .cfg .SKIP_TRACE_MODULES
138144 tracer = QTracer (skipped_modules )
@@ -141,12 +147,19 @@ def _trace(self, model):
141147 traced = fx .GraphModule (tracer .root , graph , name )
142148 traced .graph .print_tabular ()
143149 return traced
150+ == == == =
151+ def _build_regularizer (self ):
152+ if self .cfg .REGULARIZER .ENABLE :
153+ self .regularizer = build_regularizer (self .cfg )
154+ else :
155+ self .regularizer = None
156+ >> >> >> > 8 bd7bbc ... add regularizer
144157
145158 def _run_simplifiers (self ):
146159 self .model = simplify (self .model )
147160
148161 def _run_fuse_operations (self ):
149- if self .cfg .SCHEDULE .BN_TUNING : # first disable fuse bn
162+ if self .cfg .SCHEDULE .BN_TUNING : # first disable fuse bn
150163 update_config (self .cfg .SCHEDULE , "FUSE_BN" , False )
151164 self .model = fuse_operations (self .model , self .cfg .SCHEDULE )
152165 self .model .graph .print_tabular ()
@@ -167,7 +180,9 @@ def batchnorm_tuning(self):
167180 yield
168181 self .model .eval ()
169182 update_config (self .cfg .SCHEDULE , "FUSE_BN" , True )
170- self .model = fuse_operations (self .model , self .cfg .SCHEDULE , custom_fuse_list = ["fuse_bn" ])
183+ self .model = fuse_operations (
184+ self .model , self .cfg .SCHEDULE , custom_fuse_list = ["fuse_bn" ]
185+ )
171186 self .set_quant (w_quant = False , a_quant = False )
172187
173188 def prepare_calibration (self ):
@@ -235,6 +250,12 @@ def set_quant(self, w_quant=False, a_quant=False):
235250 if isinstance (m , QuantOpr ):
236251 m .set_quant (w_quant , a_quant )
237252
253+ def get_regularizer_loss (self ):
254+ if self .regularizer is None :
255+ return torch .tensor (0. ).to (self .device )
256+ else :
257+ return self .regularizer (self .model )
258+
238259 def export_onnx (
239260 self ,
240261 dummy_data ,
0 commit comments