Skip to content

Commit c512272

Browse files
committed
add regularizer
1 parent 3ef48b2 commit c512272

File tree

5 files changed

+106
-0
lines changed

5 files changed

+106
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
BACKEND: virtual
2+
W:
3+
QSCHEME: per-channel-symmetric
4+
QUANTIZER:
5+
TYPE: lsq
6+
BIT: 4
7+
A:
8+
QSCHEME: per-tensor-affine
9+
QUANTIZER:
10+
TYPE: lsq
11+
BIT: 4
12+
REGULARIZER:
13+
ENABLE: True
14+
TYPE: dampen
15+
LAMBDA: 0.01
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
REGULARIZERS_MAP = {}
2+
3+
4+
def register_regularizer(regularizer):
5+
REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
6+
return regularizer
7+
8+
9+
from .base import Regularizer
10+
from . import dampen
11+
12+
13+
def build_regularizer(config):
14+
regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
15+
return regularizer
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class Regularizer(object):
2+
def __init__(self, config):
3+
self.config = config
4+
5+
def __call__(self):
6+
pass
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
3+
from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
4+
from sparsebit.quantization.regularizers import register_regularizer
5+
6+
7+
@register_regularizer
8+
class Regularizer(BaseRegularizer):
9+
TYPE = "Dampen"
10+
11+
def __init__(self, config):
12+
super(Regularizer, self).__init__(config)
13+
self.config = config
14+
self._lambda = config.REGULARIZER.LAMBDA
15+
16+
def _get_loss(self, x, quantizer):
17+
18+
x_q = quantizer(x)
19+
20+
qmin, qmax = quantizer.qdesc.qrange
21+
22+
scale, zero_point = quantizer._qparams_preprocess(x)
23+
24+
scale = scale.detach()
25+
zero_point = zero_point.detach()
26+
27+
min_val = (qmin - zero_point) * scale
28+
29+
max_val = (qmax - zero_point) * scale
30+
31+
x_c = torch.min(torch.max(x, min_val), max_val)
32+
33+
loss = (x_q - x_c) ** 2
34+
35+
loss = loss.sum()
36+
37+
return loss
38+
39+
def __call__(self, model):
40+
loss = 0.0
41+
for n, m in model.named_modules():
42+
if (
43+
hasattr(m, "weight")
44+
and hasattr(m, "weight_quantizer")
45+
and m.weight_quantizer
46+
and m.weight_quantizer.is_enable
47+
):
48+
loss += self._get_loss(m.weight, m.weight_quantizer)
49+
return loss * self._lambda
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
3+
from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
4+
from sparsebit.quantization.regularizers import register_regularizer
5+
6+
7+
@register_regularizer
8+
class Regularizer(BaseRegularizer):
9+
TYPE = "Pact"
10+
11+
def __init__(self, config):
12+
super(Regularizer, self).__init__(config)
13+
self.config = config
14+
self._lambda = config.REGULARIZER.LAMBDA
15+
16+
def __call__(self, model):
17+
loss = 0.0
18+
for n, p in model.named_parameters():
19+
if "alpha" in n:
20+
loss += (p ** 2).sum()
21+
return loss * self._lambda

0 commit comments

Comments
 (0)