Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions sparsebit/quantization/modules/onnx/quantizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
quantize / dequantize is used in onnx export, with lower bits supported
in sparsebit.quantization.quantizers.quant_tensor.py
"""
import torch
from torch.onnx.symbolic_helper import parse_args
import torch.onnx.symbolic_helper as sym_help


def analyze_min_max(L, R):
valid_symmetric_ranges = {
(-(2 ** (i - 1)), 2 ** (i - 1) - 1): i for i in range(2, 9)
}
valid_asymmetric_ranges = {(0, 2**i - 1): i for i in range(2, 9)}
valid_asymmetric_ranges[(0, 1)] = 1
if (L, R) in valid_symmetric_ranges:
return valid_symmetric_ranges[(L, R)], True
elif (L, R) in valid_asymmetric_ranges:
return valid_asymmetric_ranges[(L, R)], False

# no valid range in <=8bit types
return None, None


@parse_args("v", "v", "v", "i", "i", "i", "b")
def onnx_quantize(
g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127, extra_info=False
):
bit, is_symmetric = analyze_min_max(quant_min, quant_max)
assert (
bit is not None and is_symmetric is not None
), "the range ({}, {}) does not identify a valid data_type with bits<=8".format(
quant_min, quant_max
)

if isinstance(scale, float):
scale = torch.tensor(scale)
scale = scale.to(torch.float32)
if isinstance(zero_point, int):
zero_point = torch.tensor(zero_point)
if is_symmetric:
zero_point = zero_point.to(torch.int8)
else:
zero_point = zero_point.to(torch.uint8)

kwargs = {"axis_i": axis}
quant_op_name = "QuantizeLinear"
dequant_op_name = "DequantizeLinear"
if extra_info:
kwargs["dtype_s"] = "{}int{}".format("s" if is_symmetric else "u", str(bit))
kwargs["bits_i"] = bit
# change the operator domain, to avoid onnx.checker.check_model failure
quant_op_name = "Sparsebit::{}".format(quant_op_name)
dequant_op_name = "Sparsebit::{}".format(dequant_op_name)
quant_op = g.op(quant_op_name, inputs, scale, zero_point, **kwargs)
dequant_op = g.op(dequant_op_name, quant_op, scale, zero_point, **kwargs)
return dequant_op


class QuantizeFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
scale,
zero_point,
axis,
quant_min=-128,
quant_max=127,
per_channel=True,
extra_info=False,
):
if per_channel:
return torch.fake_quantize_per_channel_affine(
inputs, scale, zero_point, axis, quant_min, quant_max
)
else:
return torch.fake_quantize_per_tensor_affine(
inputs, scale, zero_point, quant_min, quant_max
)

@staticmethod
def backward(ctx, grad):
return (None,) * 8

@staticmethod
def symbolic(
g: torch.Graph,
inputs: torch.Value,
scale: torch.Value,
zero_point: torch.Value,
axis: int,
quant_min: int = -128,
quant_max: int = 127,
per_channel: bool = True,
extra_info: bool = False,
):
return onnx_quantize(
g, inputs, scale, zero_point, axis, quant_min, quant_max, extra_info
)
117 changes: 28 additions & 89 deletions sparsebit/quantization/quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,95 +226,34 @@ def export_onnx(
verbose=False,
extra_info=False,
):
self.eval()
self.set_quant(w_quant=True, a_quant=True) # quant must prepared before export
for n, m in self.model.named_modules():
if isinstance(m, Quantizer):
m.enable_export_onnx()
if m.bit != 8:
assert (
extra_info
), "You must set extra_info=True when export a model with {}bit".format(
m.bit
)

torch.onnx.export(
self.model.cpu(),
dummy_data,
name,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
verbose=verbose,
from sparsebit.quantization.tools.onnx_export_wrapper import (
enable_onnx_export,
enable_extra_info_export,
)
for n, m in self.model.named_modules():
if isinstance(m, Quantizer):
m.disable_export_onnx()

if extra_info:
self.add_extra_info_to_onnx(name)

def add_extra_info_to_onnx(self, onnx_path):
onnx_model = onnx.load(onnx_path)
extra_onnx_path = onnx_path.replace(".onnx", "_extra.onnx")
tensor_inputs = {}
tensor_outputs = {}
nodes = {}
for op in onnx_model.graph.node:
nodes[op.name] = op
for inp in op.input:
if inp not in tensor_outputs:
tensor_outputs[inp] = []
tensor_outputs[inp].append(op.name)
for outp in op.output:
if outp not in tensor_inputs:
tensor_inputs[outp] = []
tensor_inputs[outp].append(op.name)

op_pos = 0
skipped_modules = set()
for name, module in self.model.named_modules():
if (
module == self.model
or isinstance(module, (Observer, Quantizer, QIdentity, Clone))
or module in skipped_modules
):
continue
if isinstance(module, QuantOpr):
for submodule in module.children():
if not isinstance(submodule, QuantOpr):
skipped_modules.add(submodule)

while op_pos < len(onnx_model.graph.node) and (
onnx_model.graph.node[op_pos].op_type
in ["QuantizeLinear", "DequantizeLinear", "Constant"]
):
op_pos += 1
onnx_op = onnx_model.graph.node[op_pos]
op_pos += 1

if isinstance(module, QuantOpr) and getattr(
module.input_quantizer, "is_enable", False
):
input_dequant = nodes[tensor_inputs[onnx_op.input[0]][0]]
input_quant = nodes[tensor_inputs[input_dequant.input[0]][0]]
input_dequant.attribute.append(
onnx.helper.make_attribute("bits", module.input_quantizer.bit)
)
input_quant.attribute.append(
onnx.helper.make_attribute("bits", module.input_quantizer.bit)
)
self.eval()
self.set_quant(w_quant=True, a_quant=True) # quant must prepared before export

if isinstance(module, QuantOpr) and getattr(
module.weight_quantizer, "is_enable", False
):
weight_dequant = nodes[tensor_inputs[onnx_op.input[1]][0]]
weight_quant = nodes[tensor_inputs[weight_dequant.input[0]][0]]
weight_dequant.attribute.append(
onnx.helper.make_attribute("bits", module.weight_quantizer.bit)
)
weight_quant.attribute.append(
onnx.helper.make_attribute("bits", module.weight_quantizer.bit)
)
onnx.save(onnx_model, extra_onnx_path)
with enable_onnx_export(self.model, extra_info=extra_info):
torch.onnx.export(
self.model.cpu(),
dummy_data,
name,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
verbose=verbose,
)
if extra_info:
with enable_extra_info_export(self.model):
torch.onnx.export(
self.model.cpu(),
dummy_data.cpu(),
name.replace(".onnx", "_external.onnx"),
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
verbose=verbose,
)
11 changes: 10 additions & 1 deletion sparsebit/quantization/quantizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, config):
self.observer = build_observer(config, self.qdesc)
self.use_quant = False
self.export_onnx = False
self.extra_info = False
self.fake_fused = False
if self.cfg.QUANTIZER.DISABLE:
self.set_fake_fused()
Expand Down Expand Up @@ -56,7 +57,9 @@ def forward(self, x):
if self.is_enable:
scale, zero_point = self._qparams_preprocess(x)
if self.export_onnx:
x_dq = torch_fake_quant(x, scale, zero_point, self.qdesc)
x_dq = torch_fake_quant(
x, scale, zero_point, self.qdesc, self.extra_info
)
else:
x_dq = self._forward(x, scale, zero_point)
else:
Expand Down Expand Up @@ -94,6 +97,12 @@ def enable_export_onnx(self):
def disable_export_onnx(self):
self.export_onnx = False

def enable_extra_info(self):
self.extra_info = True

def disable_extra_info(self):
self.extra_info = False

def _broadcast_qparams(self, params):
dst_shape = [1] * self.dims
dst_shape[self.qdesc.ch_axis] = -1
Expand Down
22 changes: 15 additions & 7 deletions sparsebit/quantization/quantizers/quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from sparsebit.quantization.common import Backend
from sparsebit.quantization.modules.onnx.quantizers import QuantizeFunc

if torch.cuda.is_available():
from torch.utils.cpp_extension import load
Expand Down Expand Up @@ -217,7 +218,7 @@ def ort_dqrange(scale, zero_point, qdesc):


# torch_fake_quant仅用作模型export to onnx使用
def torch_fake_quant(x_f, scale, zero_point, qdesc):
def torch_fake_quant(x_f, scale, zero_point, qdesc, extra_info: bool = False):
# lower_bound, upper_bound = qdesc.qrange
# set [0, 255] for quint and [-128, 127] for qint because onnx only support 8 bit
if qdesc._type.startswith("uint"):
Expand All @@ -232,18 +233,25 @@ def torch_fake_quant(x_f, scale, zero_point, qdesc):
zero_point = zero_point.reshape(-1).long().to(x_f.device)
else:
zero_point = zero_point.reshape(-1).int().to(x_f.device)
x_dq = torch.fake_quantize_per_channel_affine(
x_f, scale, zero_point, ch_axis, lower_bound, upper_bound
)
is_perchannel = True
elif scale.numel() == 1: # pertensor
ch_axis = None
scale = scale.item()
if torch.__version__.startswith("1.9"): # fix bug in 1.9.x
zero_point = zero_point.long().item()
else:
zero_point = zero_point.int().item()
x_dq = torch.fake_quantize_per_tensor_affine(
x_f, scale, zero_point, lower_bound, upper_bound
)
is_perchannel = False
else:
raise TypeError("scale / zeropoint is not allowed to be an empty tensor")
x_dq = QuantizeFunc.apply(
x_f,
scale,
zero_point,
ch_axis,
lower_bound,
upper_bound,
is_perchannel,
extra_info,
)
return x_dq
70 changes: 70 additions & 0 deletions sparsebit/quantization/tools/onnx_export_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from sparsebit.quantization.quantizers import Quantizer


class ExtraInfoContextManager:
def __init__(self, model):
self.model = model

def enable_extra_info(self):
for module in self.model.modules():
if isinstance(module, Quantizer):
module.enable_extra_info()

def disable_extra_info(self):
for module in self.model.modules():
if isinstance(module, Quantizer):
module.disable_extra_info()

def __enter__(self):
self.enable_extra_info()

def __exit__(self, exc_type, exc_value, exc_traceback):
self.disable_extra_info()


class ONNXExportContextManager:
def __init__(self, model, extra_info):
self.model = model
self.extra_info = extra_info

def enable_export_onnx(self):
for module in self.model.modules():
if isinstance(module, Quantizer):
module.enable_export_onnx()
# FIXME: if quantizer bit!=8, extra_info must be enabled
if module.bit != 8 and not self.extra_info:
assert (
False
), "8bit is supported by default. \
You must set extra_info=True when export a model with {}bit".format(
module.bit
)

def disable_export_onnx(self):
for module in self.model.modules():
if isinstance(module, Quantizer):
module.disable_export_onnx()

def __enter__(self):
self.enable_export_onnx()

def __exit__(self, exc_type, exc_value, exc_traceback):
self.disable_export_onnx()


def enable_extra_info_export(model):
"""
Usage:
with enable_extra_info_export(model):
torch.onnx.export(model, ...)
"""
return ExtraInfoContextManager(model)


def enable_onnx_export(model, extra_info=False):
"""
Usage:
with enable_onnx_export(model):
torch.onnx.export(model, ...)
"""
return ONNXExportContextManager(model, extra_info)