5 files changed, +106 −0

Changed paths:
- examples/quantization_aware_training/cifar10/basecase
- sparsebit/quantization/regularizers
New file under examples/quantization_aware_training/cifar10/basecase — a QAT config (YAML) with 4-bit LSQ quantizers for weights and activations and the new dampen regularizer enabled:

BACKEND: virtual
W:
  QSCHEME: per-channel-symmetric
  QUANTIZER:
    TYPE: lsq
    BIT: 4
A:
  QSCHEME: per-tensor-affine
  QUANTIZER:
    TYPE: lsq
    BIT: 4
REGULARIZER:
  ENABLE: True
  TYPE: dampen
  LAMBDA: 0.01
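The regularizer code below reads this block through attribute access (config.REGULARIZER.TYPE, config.REGULARIZER.LAMBDA). Here is a minimal sketch of loading the YAML into that shape; the to_ns helper and the qconfig.yaml filename are illustrative choices, not sparsebit's actual config machinery:

import yaml
from types import SimpleNamespace

def to_ns(d):
    # Recursively wrap nested dicts so keys are reachable as attributes.
    return SimpleNamespace(
        **{k: to_ns(v) if isinstance(v, dict) else v for k, v in d.items()}
    )

with open("qconfig.yaml") as f:  # hypothetical filename
    config = to_ns(yaml.safe_load(f))

print(config.REGULARIZER.TYPE, config.REGULARIZER.LAMBDA)  # -> dampen 0.01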
New file under sparsebit/quantization/regularizers (the package registry, presumably its __init__.py):

REGULARIZERS_MAP = {}


def register_regularizer(regularizer):
    # Key the registry by the class's TYPE attribute, lowercased.
    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
    return regularizer


from .base import Regularizer
from . import dampen
from . import pact  # importing a module runs its @register_regularizer decorator; without this line the Pact regularizer below never lands in the registry


def build_regularizer(config):
    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
    return regularizer
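Note the lowercasing on both sides: the config's TYPE: dampen matches the class attribute TYPE = "Dampen" because registration and lookup both call .lower(). A quick sketch of inspecting the registry (assuming sparsebit is importable):

from sparsebit.quantization.regularizers import REGULARIZERS_MAP

print(sorted(REGULARIZERS_MAP))  # e.g. ['dampen', 'pact'] once both modules are imported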
New file sparsebit/quantization/regularizers/base.py (named per the registry's from .base import):

class Regularizer(object):
    def __init__(self, config):
        self.config = config

    def __call__(self, model):
        # Subclasses compute a scalar penalty from the model; the base class
        # takes the same (model) signature they override.
        raise NotImplementedError
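To add a new penalty, a subclass only needs a TYPE attribute and a __call__(model) that returns a scalar. A hedged sketch of a hypothetical plain-L2 regularizer (not part of this diff), following the same pattern as the two files below:

import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "L2"  # a config would select it with REGULARIZER: TYPE: l2

    def __init__(self, config):
        super().__init__(config)
        self._lambda = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        # Plain weight decay over every weight tensor, scaled by LAMBDA.
        loss = 0.0
        for n, p in model.named_parameters():
            if "weight" in n:
                loss += (p ** 2).sum()
        return loss * self._lambda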
New file sparsebit/quantization/regularizers/dampen.py (named per the registry's from . import dampen). The dampen loss penalizes the squared gap between each weight and its fake-quantized value, but only inside the clipping range, so it nudges weights toward the quantization grid without fighting the clipping:

import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Dampen"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config
        self._lambda = config.REGULARIZER.LAMBDA

    def _get_loss(self, x, quantizer):
        x_q = quantizer(x)
        qmin, qmax = quantizer.qdesc.qrange
        # Detach the qparams so the penalty moves the weights,
        # not the scale/zero_point.
        scale, zero_point = quantizer._qparams_preprocess(x)
        scale = scale.detach()
        zero_point = zero_point.detach()
        # Clamp x to the representable range; outside it, x_q equals the
        # clamped value and the penalty vanishes.
        min_val = (qmin - zero_point) * scale
        max_val = (qmax - zero_point) * scale
        x_c = torch.min(torch.max(x, min_val), max_val)
        loss = ((x_q - x_c) ** 2).sum()
        return loss

    def __call__(self, model):
        loss = 0.0
        for n, m in model.named_modules():
            if (
                hasattr(m, "weight")
                and hasattr(m, "weight_quantizer")
                and m.weight_quantizer
                and m.weight_quantizer.is_enable
            ):
                loss += self._get_loss(m.weight, m.weight_quantizer)
        return loss * self._lambda
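To see what _get_loss computes, here is a self-contained toy run; ToyQuantizer is a hypothetical stand-in exposing just the interface the method touches (qdesc.qrange, _qparams_preprocess, __call__), not sparsebit's real quantizer:

import torch
from types import SimpleNamespace

class ToyQuantizer:
    # Hypothetical per-tensor symmetric fake quantizer.
    def __init__(self, scale, bit=4):
        self.scale = torch.tensor(scale)
        self.zero_point = torch.tensor(0.0)
        self.qdesc = SimpleNamespace(qrange=(-(2 ** (bit - 1)), 2 ** (bit - 1) - 1))

    def _qparams_preprocess(self, x):
        return self.scale, self.zero_point

    def __call__(self, x):
        qmin, qmax = self.qdesc.qrange
        return torch.clamp(torch.round(x / self.scale), qmin, qmax) * self.scale

w = torch.tensor([0.04, 0.26, 1.50])  # 1.50 lies far outside the int4 range
quant = ToyQuantizer(scale=0.1)       # representable floats: [-0.8, 0.7]
qmin, qmax = quant.qdesc.qrange
scale, zp = quant._qparams_preprocess(w)
x_q = quant(w)                                     # -> [0.00, 0.30, 0.70]
min_val = (qmin - zp) * scale
max_val = (qmax - zp) * scale
x_c = torch.min(torch.max(w, min_val), max_val)    # -> [0.04, 0.26, 0.70]
print(((x_q - x_c) ** 2).sum())  # tensor(0.0032): only in-range rounding error counts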
New file under sparsebit/quantization/regularizers (presumably pact.py, following the dampen module's naming). PACT's regularizer is an L2 penalty on the learnable clipping thresholds alpha, which keeps the activation clipping range from growing unboundedly:

import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Pact"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config
        self._lambda = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        loss = 0.0
        for n, p in model.named_parameters():
            if "alpha" in n:  # PACT's learnable clipping parameters
                loss += (p ** 2).sum()
        return loss * self._lambda
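Putting the pieces together: a hedged sketch of how a QAT training step could fold the regularizer into the task loss; qmodel, criterion, loader, and the SGD settings are assumed for illustration and are not part of this diff:

import torch

from sparsebit.quantization.regularizers import build_regularizer

regularizer = build_regularizer(config) if config.REGULARIZER.ENABLE else None
optimizer = torch.optim.SGD(qmodel.parameters(), lr=0.01)

for images, targets in loader:
    loss = criterion(qmodel(images), targets)
    if regularizer is not None:
        loss = loss + regularizer(qmodel)  # penalty is already scaled by LAMBDA
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()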