Skip to content

Commit 8397203

Browse files
committed
add regularizer
1 parent c512272 commit 8397203

File tree

5 files changed

+35
-4
lines changed

5 files changed

+35
-4
lines changed

examples/quantization_aware_training/cifar10/basecase/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import time
66
import warnings
77
from enum import Enum
8+
import math
89

910
import torch
1011
import torch.nn as nn
@@ -278,7 +279,7 @@ def train(
278279
# compute output
279280
output = model(images)
280281
ce_loss = criterion(output, target)
281-
regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
282+
regular_loss = model.get_regularizer_loss() * schedule_value
282283
loss = ce_loss + regular_loss
283284

284285
# measure accuracy and record loss

examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ A:
99
QUANTIZER:
1010
TYPE: pact
1111
BIT: 4
12+
REGULARIZER:
13+
ENABLE: True
14+
TYPE: pact
15+
LAMBDA: 0.0001

sparsebit/quantization/quant_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@
4343
_C.A.OBSERVER.LAYOUT = "NCHW" # NCHW / NLC
4444
_C.A.SPECIFIC = []
4545

46+
_C.REGULARIZER = CN()
47+
_C.REGULARIZER.ENABLE = False
48+
_C.REGULARIZER.TYPE = ""
49+
_C.REGULARIZER.LAMBDA = 0.0
50+
4651

4752
def parse_qconfig(cfg_file):
4853
qconfig = _parse_config(cfg_file, default_cfg=_C)

sparsebit/quantization/quant_model.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from sparsebit.quantization.quantizers import Quantizer
2121
from sparsebit.quantization.tools import QuantizationErrorProfiler
2222
from sparsebit.quantization.converters import simplify, fuse_operations
23+
<<<<<<< HEAD
2324
from sparsebit.quantization.quant_tracer import QTracer
25+
=======
26+
from sparsebit.quantization.regularizers import build_regularizer
27+
>>>>>>> 8bd7bbc... add regularizer
2428

2529

2630
__all__ = ["QuantModel"]
@@ -35,6 +39,7 @@ def __init__(self, model: nn.Module, config):
3539
self._run_simplifiers()
3640
self._convert2quantmodule()
3741
self._build_quantizer()
42+
self._build_regularizer()
3843
self._run_fuse_operations()
3944

4045
def _convert2quantmodule(self):
@@ -133,6 +138,7 @@ def _sub_build(src, module_name):
133138
update_config(_config, "A", _sub_build(self.cfg.A, node.target))
134139
identity_module.build_quantizer(_config)
135140

141+
<<<<<<< HEAD
136142
def _trace(self, model):
137143
skipped_modules = self.cfg.SKIP_TRACE_MODULES
138144
tracer = QTracer(skipped_modules)
@@ -141,12 +147,19 @@ def _trace(self, model):
141147
traced = fx.GraphModule(tracer.root, graph, name)
142148
traced.graph.print_tabular()
143149
return traced
150+
=======
151+
def _build_regularizer(self):
152+
if self.cfg.REGULARIZER.ENABLE:
153+
self.regularizer = build_regularizer(self.cfg)
154+
else:
155+
self.regularizer = None
156+
>>>>>>> 8bd7bbc... add regularizer
144157

145158
def _run_simplifiers(self):
146159
self.model = simplify(self.model)
147160

148161
def _run_fuse_operations(self):
149-
if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
162+
if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
150163
update_config(self.cfg.SCHEDULE, "FUSE_BN", False)
151164
self.model = fuse_operations(self.model, self.cfg.SCHEDULE)
152165
self.model.graph.print_tabular()
@@ -167,7 +180,9 @@ def batchnorm_tuning(self):
167180
yield
168181
self.model.eval()
169182
update_config(self.cfg.SCHEDULE, "FUSE_BN", True)
170-
self.model = fuse_operations(self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"])
183+
self.model = fuse_operations(
184+
self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"]
185+
)
171186
self.set_quant(w_quant=False, a_quant=False)
172187

173188
def prepare_calibration(self):
@@ -235,6 +250,12 @@ def set_quant(self, w_quant=False, a_quant=False):
235250
if isinstance(m, QuantOpr):
236251
m.set_quant(w_quant, a_quant)
237252

253+
def get_regularizer_loss(self):
    """Return the regularization loss term for the current model.

    Returns:
        torch.Tensor: the configured regularizer applied to ``self.model``,
        or a scalar zero tensor on ``self.device`` when no regularizer is
        enabled (so callers can add it to the task loss unconditionally).
    """
    if self.regularizer is None:
        # Allocate the zero directly on the target device instead of
        # building a CPU tensor and transferring it with ``.to()``.
        return torch.tensor(0.0, device=self.device)
    return self.regularizer(self.model)
258+
238259
def export_onnx(
239260
self,
240261
dummy_data,

sparsebit/quantization/regularizers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def register_regularizer(regularizer):
77

88

99
from .base import Regularizer
10-
from . import dampen
10+
from . import dampen, pact
1111

1212

1313
def build_regularizer(config):

0 commit comments

Comments
 (0)