diff --git a/examples/experimental/torch/classification/README.md b/examples/experimental/torch/classification/README.md new file mode 100644 index 00000000000..732da8f7bb9 --- /dev/null +++ b/examples/experimental/torch/classification/README.md @@ -0,0 +1,40 @@ +# FracBits mixed-precision quantization algorithm + +This provides sample configurations of the FracBits mixed-precision quantization algorithm for image classification tasks. + +## Prerequisite + +Please follow the [installation guide](../../../torch/classification/README.md#installation) and the [dataset preparation guide](../../../torch/classification/README.md#dataset-preparation) of the NNCF PyTorch classification examples. + +## Compress FP32 model with FracBits + +You can run the FracBits mixed-precision quantization algorithm with one of the pre-defined configuration files. + +```bash +cd examples/experimental/torch/classification +python fracbits.py -m train -c <config_path> -j <num_workers> --data <dataset_path> --log-dir <log_dir> +``` + +The following describes each argument. + +- `-c`: FracBits configuration file path. You can find the provided configurations in `examples/experimental/torch/classification/fracbits_configs`. +- `-j`: The number of PyTorch dataloader workers. +- `--data`: Directory path of the dataset. +- `--log-dir`: Directory path to save log files, tensorboard logs, and model checkpoints. + +We provide configurations for three model architectures: `inception_v3`, `mobilenet_v2`, and `resnet50`. All configurations use the ImageNet dataset, except that `mobilenet_v2` also has a configuration for the CIFAR100 dataset. + +## Results for FracBits + +| Model | Compression algorithm | Dataset | Accuracy (Drop) % | NNCF config file | Compression rate | +| :----------: | :-------------------: | :------: | :---------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------: | +| MobileNet-V2 | FracBits | CIFAR100 | 67.26 (0.45) | [mobilenet_v2_cifar100_mixed_int_fracbits_msize.json](./fracbits_configs/mobilenet_v2_cifar100_mixed_int_fracbits_msize.json) | 1.5 | +| Inception-V3 | FracBits | ImageNet | 78.16 (-0.82) | [inception_v3_imagenet_mixed_int_fracbits_msize.json](./fracbits_configs/inception_v3_imagenet_mixed_int_fracbits_msize.json) | 1.51 | +| MobileNet-V2 | FracBits | ImageNet | 71.19 (0.68) | [mobilenet_v2_imagenet_mixed_int_fracbits_msize.json](./fracbits_configs/mobilenet_v2_imagenet_mixed_int_fracbits_msize.json) | 1.53 | +| ResNet-50 | FracBits | ImageNet | 76.12 (0.04) | [resnet50_imagenet_mixed_int_fracbits_msize.json](./fracbits_configs/resnet50_imagenet_mixed_int_fracbits_msize.json) | 1.54 | + +- We used a machine with 8 NVIDIA V100 GPUs to obtain all results except for the MobileNet-V2 CIFAR100 experiment. +- Model accuracy is obtained by averaging over 5 runs. +- Absolute accuracy drop is computed relative to the FP32 model accuracy reported in [Results for quantization](../../../torch/classification/README.md#results-for-quantization). +- Compression rate is the ratio of the initial model size to the compressed model size. The initial state is the INT8-quantized model, so a compression rate of 1.5 means that the model size is reduced to 2/3 of the INT8 model size. +- Model size is the total number of bits in the model weights. It is computed as $\sum_i \textrm{num-params}_i \times \textrm{bitwidth}_i$, where $\textrm{num-params}_i$ is the number of parameters of the $i$-th layer and $\textrm{bitwidth}_i$ is the bit-width of the $i$-th layer.
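+
+For illustration, the model size metric and the compression rate from the last bullet can be computed with a short standalone snippet. This is a minimal sketch, not part of the example scripts; the toy model and its per-layer bit-width assignment below are hypothetical:
+
+```python
+from torch import nn
+
+def model_size_in_bits(model: nn.Module, bitwidths: dict) -> int:
+    """Sum of num-params * bitwidth over the weighted (conv/linear) layers."""
+    total = 0
+    for name, module in model.named_modules():
+        if isinstance(module, (nn.Conv2d, nn.Linear)):
+            total += module.weight.numel() * bitwidths.get(name, 8)
+    return total
+
+model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16 * 30 * 30, 10))
+mixed_bits = {"0": 6, "3": 4}  # hypothetical per-layer bit-widths found by FracBits
+int8_bits = {"0": 8, "3": 8}   # INT8 baseline, i.e. the initial state
+
+rate = model_size_in_bits(model, int8_bits) / model_size_in_bits(model, mixed_bits)
+print(f"Compression rate vs. INT8: {rate:.2f}")
+```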
diff --git a/examples/experimental/torch/classification/fracbits.py b/examples/experimental/torch/classification/fracbits.py new file mode 100644 index 00000000000..0b6d0921871 --- /dev/null +++ b/examples/experimental/torch/classification/fracbits.py @@ -0,0 +1,22 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# pylint: disable=unused-import + +import sys + +from examples.torch.classification.main import main +from nncf.experimental.torch.fracbits.builder import FracBitsQuantizationBuilder + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/examples/experimental/torch/classification/fracbits_configs/inception_v3_imagenet_mixed_int_fracbits_msize.json b/examples/experimental/torch/classification/fracbits_configs/inception_v3_imagenet_mixed_int_fracbits_msize.json new file mode 100644 index 00000000000..fee34f55ed8 --- /dev/null +++ b/examples/experimental/torch/classification/fracbits_configs/inception_v3_imagenet_mixed_int_fracbits_msize.json @@ -0,0 +1,46 @@ +{ + "model": "inception_v3", + "pretrained": true, + "input_info": { + "sample_size": [2, 3, 299, 299] + }, + "num_classes": 1000, + "batch_size" : 512, + "epochs": 1, + "multiprocessing_distributed" : true, + "find_unused_parameters": true, + "optimizer": { + "type": "SGD", + "base_lr": 1e-3, + "schedule_type": "multistep", + "steps": [ + 1 + ] + }, + "compression": { + "algorithm": "fracbits_quantization", + "overflow_fix": "disable", + "initializer": { + "range": { + "num_init_samples": 160 + } + }, + "ignored_scopes": [ + "Inception3/__add___0", + "Inception3/__add___1", + "Inception3/__add___2", + "Inception3/__mul___0", + "Inception3/__mul___1", + "Inception3/__mul___2", + "Inception3/cat_0" + ], + "freeze_epoch": -1, + "loss": { + "type": "model_size", + "compression_rate": 1.5, + "criteria": "L1", + "flip_loss": false, + "alpha": 40.0 + } + } +} diff --git a/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_cifar100_mixed_int_fracbits_msize.json b/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_cifar100_mixed_int_fracbits_msize.json new file mode 100644 index 00000000000..5a86b5c34f9 --- /dev/null +++ b/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_cifar100_mixed_int_fracbits_msize.json @@ -0,0 +1,33 @@ +{ + "model": "mobilenet_v2_32x32", + "pretrained": false, + "input_info": { + "sample_size": [2, 3, 32, 32] + }, + "num_classes": 100, + "batch_size": 256, + "optimizer": { + "type": "SGD", + "base_lr": 1e-3 + }, + "compression": { + "algorithm": "fracbits_quantization", + "overflow_fix": "disable", + "activations": { + "mode": "asymmetric" + }, + "weights": { + "mode": "asymmetric" + }, + "freeze_epoch": 4, + "loss": { + "type": "model_size", + "compression_rate": 1.5, + "criteria": "L1", + "flip_loss": false, + "alpha": 40.0 + } + }, + "epochs": 5, + "dataset": "CIFAR100" +} diff --git a/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_imagenet_mixed_int_fracbits_msize.json 
b/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_imagenet_mixed_int_fracbits_msize.json new file mode 100644 index 00000000000..5bedc4e676b --- /dev/null +++ b/examples/experimental/torch/classification/fracbits_configs/mobilenet_v2_imagenet_mixed_int_fracbits_msize.json @@ -0,0 +1,37 @@ +{ + "model": "mobilenet_v2", + "pretrained": true, + "input_info": { + "sample_size": [2, 3, 224, 224] + }, + "num_classes": 1000, + "batch_size" : 1024, + "epochs": 5, + "multiprocessing_distributed": true, + "find_unused_parameters": true, + "optimizer": { + "type": "SGD", + "base_lr": 1e-3, + "schedule_type": "multistep", + "steps": [ + 5 + ] + }, + "compression": { + "algorithm": "fracbits_quantization", + "overflow_fix": "disable", + "initializer": { + "range": { + "num_init_samples": 2560 + } + }, + "freeze_epoch": 4, + "loss": { + "type": "model_size", + "compression_rate": 1.5, + "criteria": "L1", + "flip_loss": false, + "alpha": 40.0 + } + } +}
diff --git a/examples/experimental/torch/classification/fracbits_configs/resnet50_imagenet_mixed_int_fracbits_msize.json b/examples/experimental/torch/classification/fracbits_configs/resnet50_imagenet_mixed_int_fracbits_msize.json new file mode 100644 index 00000000000..026ee2ce760 --- /dev/null +++ b/examples/experimental/torch/classification/fracbits_configs/resnet50_imagenet_mixed_int_fracbits_msize.json @@ -0,0 +1,34 @@ +{ + "model": "resnet50", + "pretrained": true, + + "input_info": { + "sample_size": [1, 3, 224, 224] + }, + "num_classes": 1000, + "batch_size": 512, + "epochs": 1, + "multiprocessing_distributed": true, + "find_unused_parameters": true, + "optimizer": { + "type": "SGD", + "base_lr": 1e-3 + }, + "compression": { + "algorithm": "fracbits_quantization", + "overflow_fix": "disable", + "initializer": { + "range": { + "num_init_samples": 850 + } + }, + "freeze_epoch": -1, + "loss": { + "type": "model_size", + "compression_rate": 1.5, + "criteria": "L1", + "flip_loss": false, + "alpha": 40.0 + } + } +}
diff --git a/examples/torch/common/argparser.py b/examples/torch/common/argparser.py index f3c0104ec46..261a164abe0 100644 --- a/examples/torch/common/argparser.py +++ b/examples/torch/common/argparser.py @@ -183,6 +183,15 @@ def get_common_argument_parser(): help="Disable compression", action="store_true", ) + + parser.add_argument( + "--find-unused-parameters", + help="For distributed execution mode. If true, " + "parameters that don't receive gradients as part of the graph " + "are preemptively marked as ready to be reduced. " + "This option should be enabled for FracBits if freeze_epoch > 0.", + action="store_true", + ) return parser
diff --git a/examples/torch/common/execution.py b/examples/torch/common/execution.py index c0e870dce6f..1b5e0f9b0af 100644 --- a/examples/torch/common/execution.py +++ b/examples/torch/common/execution.py @@ -73,13 +73,14 @@ def prepare_model_for_execution(model, config): # should always set the single device scope, otherwise, # DistributedDataParallel will use all available devices.
torch.cuda.set_device(config.current_gpu) - model = torch.nn.parallel.distributed.DistributedDataParallel(model, device_ids=[config.current_gpu]) + model = torch.nn.parallel.distributed.DistributedDataParallel( + model, device_ids=[config.current_gpu], find_unused_parameters=config.find_unused_parameters) model_without_dp = model.module if config.execution_mode == ExecutionMode.DISTRIBUTED: # DistributedDataParallel will divide and allocate batch_size to all # available GPUs if device_ids are not set - model = torch.nn.parallel.DistributedDataParallel(model) + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=config.find_unused_parameters) model_without_dp = model.module if config.execution_mode == ExecutionMode.SINGLE_GPU:
diff --git a/examples/torch/requirements.txt b/examples/torch/requirements.txt index b5c8942364c..3152ff9b3df 100644 --- a/examples/torch/requirements.txt +++ b/examples/torch/requirements.txt @@ -8,3 +8,7 @@ returns==0.14 opencv-python>=4.4.0.46 torchvision==0.10.1 # should always match the torch version that is installed via NNCF's setup.py efficientnet_pytorch + +# Please see +# https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version +setuptools==59.5.0
diff --git a/nncf/common/statistics.py b/nncf/common/statistics.py index 4f5cff0e008..ba46d084e43 100644 --- a/nncf/common/statistics.py +++ b/nncf/common/statistics.py @@ -104,12 +104,13 @@ def register(self, algorithm_name: str, stats: Statistics): - quantization - filter_pruning - binarization + - fracbits_quantization :param stats: Statistics of the algorithm. """ available_algorithms = [ 'magnitude_sparsity', 'rb_sparsity', 'const_sparsity', - 'quantization', 'filter_pruning', 'binarization' + 'quantization', 'filter_pruning', 'binarization', 'fracbits_quantization' ] if algorithm_name not in available_algorithms: raise ValueError('Can not register statistics for the algorithm. '
diff --git a/nncf/config/schemata/experimental_schema.py b/nncf/config/schemata/experimental_schema.py index f55e4150a29..41b7b65367d 100644 --- a/nncf/config/schemata/experimental_schema.py +++ b/nncf/config/schemata/experimental_schema.py @@ -290,11 +290,38 @@ "additionalProperties": False } +######################################################################################################################## +# FracBits Quantization +######################################################################################################################## +FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG = 'fracbits_quantization' +FRACBITS_QUANTIZATION_SCHEMA = copy.deepcopy(QUANTIZATION_SCHEMA) +FRACBITS_QUANTIZATION_SCHEMA['properties']['algorithm']['const'] = FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG +FRACBITS_QUANTIZATION_SCHEMA['properties']['freeze_epoch'] = with_attributes( + NUMBER, description="The epoch at which to freeze fractional bit widths to integers by rounding them.") +FRACBITS_QUANTIZATION_SCHEMA['properties']['loss'] = { + "type": "object", + "properties": { + "type": with_attributes(STRING, description="Type of compression loss. Choose model_size or bitops."), + "compression_rate": with_attributes(NUMBER, description="Target compression rate."), + "criteria": with_attributes(STRING, description="Criteria to measure the distance between the target " + "compression rate and the current compression rate. 
Choose L1 or L2."), + "flip_loss": with_attributes( + BOOLEAN, + description="If true, we compute the compression loss by " + "|1 / target_compression_rate - (current_model_size / target_model_size)|, " + "otherwise, we compute it by " + "|target_compression_rate - (target_model_size / current_model_size)|."), + "alpha": with_attributes(NUMBER, description="Scale multiplier for the compression loss."), + }, + "additionalProperties": False +} + ######################################################################################################################## # All experimental schemas ######################################################################################################################## EXPERIMENTAL_REF_VS_ALGO_SCHEMA = { EXPERIMENTAL_QUANTIZATION_ALGO_NAME_IN_CONFIG: EXPERIMENTAL_QUANTIZATION_SCHEMA, - BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA + BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA, + FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG: FRACBITS_QUANTIZATION_SCHEMA } diff --git a/nncf/experimental/torch/fracbits/builder.py b/nncf/experimental/torch/fracbits/builder.py new file mode 100644 index 00000000000..02e06ff3473 --- /dev/null +++ b/nncf/experimental/torch/fracbits/builder.py @@ -0,0 +1,49 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from nncf.experimental.torch.fracbits.controller import FracBitsQuantizationController +from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS +from nncf.torch.compression_method_api import PTCompressionAlgorithmController +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.algo import QuantizationBuilder +from nncf.torch.quantization.layers import PTQuantizerSetup +from nncf.common.quantization.structs import QuantizationMode +from nncf.experimental.torch.fracbits.quantizer import FracBitsQuantizationMode + + +@PT_COMPRESSION_ALGORITHMS.register('fracbits_quantization') +class FracBitsQuantizationBuilder(QuantizationBuilder): + def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup: + setup = super()._get_quantizer_setup(target_model) + + for q_point in setup.quantization_points.values(): + mode = q_point.qspec.mode + if mode == QuantizationMode.ASYMMETRIC: + q_point.qspec.mode = FracBitsQuantizationMode.ASYMMETRIC + elif mode == QuantizationMode.SYMMETRIC: + q_point.qspec.mode = FracBitsQuantizationMode.SYMMETRIC + else: + raise ValueError(f"qsepc.mode={mode} is unknown.") + + return setup + + def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController: + return FracBitsQuantizationController(model, + self.config, + self._debug_interface, + self._weight_quantizers, + self._non_weight_quantizers, + self._groups_of_adjacent_quantizers, + self._quantizers_input_shapes, + build_time_metric_info=self._build_time_metric_infos, + build_time_range_init_params=self._range_init_params) diff --git a/nncf/experimental/torch/fracbits/controller.py b/nncf/experimental/torch/fracbits/controller.py new file mode 100644 index 00000000000..ce24ebf7bc9 --- /dev/null +++ b/nncf/experimental/torch/fracbits/controller.py @@ -0,0 +1,85 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from contextlib import contextmanager +from typing import Dict, Tuple + +from nncf.common.quantization.structs import NonWeightQuantizerId, QuantizerId, WeightQuantizerId +from nncf.common.statistics import NNCFStatistics +from nncf.config.config import NNCFConfig +from nncf.config.extractors import extract_algo_specific_config +from nncf.experimental.torch.fracbits.statistics import FracBitsStatistics +from nncf.experimental.torch.fracbits.scheduler import FracBitsQuantizationScheduler +from nncf.torch.compression_method_api import PTCompressionLoss +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.algo import QuantizationController, QuantizationDebugInterface +from nncf.torch.quantization.init_range import PTRangeInitParams +from nncf.torch.quantization.metrics import QuantizationShareBuildTimeInfo +from nncf.torch.quantization.precision_init.adjacent_quantizers import GroupsOfAdjacentQuantizers +from nncf.torch.quantization.structs import NonWeightQuantizerInfo, WeightQuantizerInfo +from nncf.experimental.torch.fracbits.loss import FRACBITS_LOSSES +from nncf.experimental.torch.fracbits.params import FracBitsSchedulerParams, FracBitsLossParams + + +class FracBitsQuantizationController(QuantizationController): + def __init__(self, target_model: NNCFNetwork, + config: NNCFConfig, + debug_interface: QuantizationDebugInterface, + weight_quantizers: Dict[WeightQuantizerId, WeightQuantizerInfo], + non_weight_quantizers: Dict[NonWeightQuantizerId, NonWeightQuantizerInfo], + groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers, + quantizers_input_shapes: Dict[QuantizerId, Tuple[int]], + build_time_metric_info: QuantizationShareBuildTimeInfo = None, + build_time_range_init_params: PTRangeInitParams = None): + super().__init__(target_model, config, debug_interface, weight_quantizers, non_weight_quantizers, + groups_of_adjacent_quantizers, quantizers_input_shapes, + build_time_metric_info, build_time_range_init_params) + self._set_fracbits_loss(target_model) + self._set_scheduler() + + def _set_fracbits_loss(self, target_model: NNCFNetwork): + algo_config = self._get_algo_config() + loss_config = algo_config.get("loss", {}) + params = FracBitsLossParams.from_config(loss_config) + self._loss: PTCompressionLoss = FRACBITS_LOSSES.get(params.type)(target_model, params) + + def _set_scheduler(self): + algo_config = self._get_algo_config() + params = FracBitsSchedulerParams.from_config(algo_config) + + def _callback(): + self.freeze_bit_widths() + + self._scheduler = FracBitsQuantizationScheduler(freeze_callback=_callback, params=params) + + def _get_algo_config(self) -> Dict: + return extract_algo_specific_config(self.config, algo_name_to_match="fracbits_quantization") + + def freeze_bit_widths(self): + for q in self.all_quantizations.values(): + q.freeze_num_bits() + + def statistics(self, quickly_collected_only=False) -> NNCFStatistics: + @contextmanager + def _base_name_context(): + tmp_name = self._name + self._name = "quantization" + yield self.name + self._name = tmp_name + + with _base_name_context(): + nncf_statistics = super().statistics(quickly_collected_only) + + nncf_statistics.register(self.name, FracBitsStatistics(self._loss.get_state())) + + return nncf_statistics diff --git a/nncf/experimental/torch/fracbits/loss.py b/nncf/experimental/torch/fracbits/loss.py new file mode 100644 index 00000000000..6e83962ecd6 --- /dev/null +++ b/nncf/experimental/torch/fracbits/loss.py @@ -0,0 +1,116 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under 
the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from numbers import Number +from typing import Dict, Union +import torch +from nncf.common.utils.registry import Registry +from nncf.experimental.torch.fracbits.params import FracBitsLossParams +from nncf.torch.compression_method_api import PTCompressionLoss +from nncf.torch.module_operations import UpdateWeight +from nncf.torch.nncf_network import NNCFNetwork +from torch import nn +from dataclasses import dataclass +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.common.utils.logger import logger as nncf_logger + + +FRACBITS_LOSSES = Registry("fracbits_loss") +EPS = 1e-6 + + +@dataclass +class ModuleQuantizerPair: + module: nn.Module + quantizer: BaseQuantizer + + +@FRACBITS_LOSSES.register("model_size") +class ModelSizeCompressionLoss(PTCompressionLoss): + def __init__(self, model: NNCFNetwork, params: FracBitsLossParams): + super().__init__() + self._model = model + self._compression_rate = torch.FloatTensor([params.compression_rate]) + self._criteria = self._get_criteria(params.criteria) + + self._w_q_pairs: Dict[str, ModuleQuantizerPair] = {} + + for name, module in self._model.named_modules(): + if isinstance(module, UpdateWeight): + parent_name = ".".join(name.split(".")[:-2]) + parent_module = self._model.get_submodule(parent_name) + + self._w_q_pairs[parent_name] = ModuleQuantizerPair(parent_module, module.op) + + with torch.no_grad(): + self._init_model_size = self._get_model_size() + + self._flip_loss = params.flip_loss + self._alpha = params.alpha + + def calculate(self) -> torch.Tensor: + if self._flip_loss: + cur_comp_rate = self._get_frac_model_size() / self._init_model_size + tgt_comp_rate = 1 / self._compression_rate.to(device=cur_comp_rate.device) + else: + cur_comp_rate = self._init_model_size / (self._get_frac_model_size() + EPS) + tgt_comp_rate = self._compression_rate.to(device=cur_comp_rate.device) + + return self._alpha * self._criteria(cur_comp_rate, tgt_comp_rate) + + def _get_criteria(self, criteria) -> nn.modules.loss._Loss: + if criteria == "L1": + return nn.L1Loss() + if criteria == "L2": + return nn.MSELoss() + raise RuntimeError(f"Unknown criteria = {criteria}.") + + @staticmethod + def _get_module_size(module: nn.Module, num_bits: Union[int, torch.Tensor]) -> Union[torch.Tensor, Number]: + if isinstance(module, (nn.modules.conv._ConvNd, nn.Linear)): # pylint: disable=protected-access + return module.weight.shape.numel() * num_bits + nncf_logger.warning(f"module={module} is not supported by ModelSizeCompressionLoss. Skipping it.") + return 0. 
+ + def _get_frac_model_size(self) -> torch.Tensor: + return sum([self._get_module_size(pair.module, pair.quantizer.frac_num_bits) + for pair in self._w_q_pairs.values()]) + + def _get_model_size(self) -> Number: + return sum([self._get_module_size(pair.module, pair.quantizer.num_bits) for pair in self._w_q_pairs.values()]) + + @torch.no_grad() + def get_state(self) -> Dict[str, Number]: + curr_model_size = self._get_model_size() + frac_model_size = self._get_frac_model_size() + + states = { + "current_model_size": curr_model_size, + "fractional_model_size": frac_model_size.item(), + "compression_rate": self._init_model_size / (curr_model_size + EPS) + } + + for name, pair in self._w_q_pairs.items(): + states[f"frac_bits/{name}"] = pair.quantizer.frac_num_bits.item() + + return states + + +@FRACBITS_LOSSES.register("bitops") +class BitOpsCompressionLoss(PTCompressionLoss): + def calculate(self) -> torch.Tensor: + raise NotImplementedError() + + @torch.no_grad() + def get_state(self) -> Dict[str, Number]: + raise NotImplementedError()
diff --git a/nncf/experimental/torch/fracbits/params.py b/nncf/experimental/torch/fracbits/params.py new file mode 100644 index 00000000000..977be75cb6c --- /dev/null +++ b/nncf/experimental/torch/fracbits/params.py @@ -0,0 +1,36 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dataclasses import dataclass, fields +from typing import Dict + + +class FracBitsParamsBase: + @classmethod + def from_config(cls, config: Dict) -> "FracBitsParamsBase": + attr_names = {f.name for f in fields(cls)} + return cls(**{k: v for k, v in config.items() if k in attr_names}) + + +@dataclass +class FracBitsLossParams(FracBitsParamsBase): + type: str = "model_size" + compression_rate: float = 1.5 + criteria: str = "L1" + flip_loss: bool = False + alpha: float = 10.0 + + +@dataclass +class FracBitsSchedulerParams(FracBitsParamsBase): + freeze_epoch: int = -1
diff --git a/nncf/experimental/torch/fracbits/quantizer.py b/nncf/experimental/torch/fracbits/quantizer.py new file mode 100644 index 00000000000..0e30d344ecb --- /dev/null +++ b/nncf/experimental/torch/fracbits/quantizer.py @@ -0,0 +1,220 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# Reference: Yang, Linjie, and Qing Jin. "Fracbits: Mixed precision quantization via fractional bit-widths." +# Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 35. No. 12. 2021. 
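+#
+# Summary of the scheme implemented below (editor's gloss on the paper above):
+# each quantizer holds a learnable, continuous bit-width parameter b
+# (`_num_bits`, clamped to [_min_num_bits, _max_num_bits] via `frac_num_bits`).
+# While b is trainable, `quantize()` evaluates the quantization kernel at
+# floor(b) and floor(b) + 1 bits and linearly interpolates the two results:
+#     q = (b - floor(b)) * q_ceil + (ceil(b) - b) * q_floor
+# so the output is differentiable w.r.t. b and the compression loss can push
+# bit-widths up or down. Once `freeze_num_bits()` is called, the gradient is
+# disabled and the quantizer falls back to ordinary integer-bit quantization
+# with the rounded `num_bits`.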
+ +from typing import Dict +import torch + +from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter +from nncf.torch.quantization.layers import ( + QUANTIZATION_MODULES, AsymmetricQuantizer, PTQuantizerSpec, SymmetricQuantizer) +from nncf.torch.quantization.quantize_functions import asymmetric_quantize, symmetric_quantize +from nncf.torch.utils import no_jit_trace +from nncf.common.quantization.structs import QuantizationMode + + +class FracBitsQuantizationMode(QuantizationMode): + SYMMETRIC = 'fracbits_symmetric' + ASYMMETRIC = 'fracbits_asymmetric' + + +@COMPRESSION_MODULES.register() +@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.SYMMETRIC) +class FracBitsSymmetricQuantizer(SymmetricQuantizer): + def __init__(self, qspec: PTQuantizerSpec): + super().__init__(qspec) + self._min_num_bits = int(0.5 * qspec.num_bits) + self._max_num_bits = int(1.5 * qspec.num_bits) + self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True, + compression_lr_multiplier=qspec.compression_lr_multiplier) + + @property + def frac_num_bits(self): + return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits) + + @property + def num_bits(self): + if self._num_bits.dtype == torch.int32: + return super().num_bits + + with no_jit_trace(): + return self.frac_num_bits.round().int().item() + + @num_bits.setter + def num_bits(self, num_bits: int): + if num_bits < self._min_num_bits or num_bits > self._max_num_bits: + raise RuntimeError( + f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]") + self._num_bits.fill_(num_bits) + + @property + def is_num_bits_frozen(self) -> bool: + return not self._num_bits.requires_grad + + def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int): + if min_num_bits >= max_num_bits: + raise ValueError( + f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})") + self._min_num_bits = min_num_bits + self._max_num_bits = max_num_bits + + def unfreeze_num_bits(self) -> None: + self._num_bits.requires_grad_(True) + + def freeze_num_bits(self) -> None: + self._num_bits.requires_grad_(False) + super().set_level_ranges() + + def enable_gradients(self): + super().enable_gradients() + self.unfreeze_num_bits() + + def disable_gradients(self): + super().disable_gradients() + self.freeze_num_bits() + + def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False): + scaled_num_bits = 1 if self._half_range else 0 + + level_low, level_high, levels = self.calculate_level_ranges( + num_bits - scaled_num_bits, self.signed) + + return symmetric_quantize(x, levels, level_low, level_high, self.scale, self.eps, + skip=execute_traced_op_as_identity) + + def quantize(self, x, execute_traced_op_as_identity: bool = False): + if self.is_num_bits_frozen: + return super().quantize(x, execute_traced_op_as_identity) + + fl_num_bits = self.frac_num_bits.floor().int().item() + ce_num_bits = fl_num_bits + 1 + + fl_q = self._quantize_with_n_bits( + x, fl_num_bits, execute_traced_op_as_identity) + ce_q = self._quantize_with_n_bits( + x, ce_num_bits, execute_traced_op_as_identity) + + return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q + + def get_trainable_params(self) -> Dict[str, torch.Tensor]: + return {self.SCALE_PARAM_NAME: self.scale.detach(), "num_bits": self.frac_num_bits.detach()} + + def _prepare_export_quantization(self, x: torch.Tensor): + self.freeze_num_bits() + return super()._prepare_export_quantization(x) + + 
@torch.no_grad() + def get_input_range(self): + self.set_level_ranges() + input_low, input_high = self._get_input_low_input_high( + self.scale, self.level_low, self.level_high, self.eps) + return input_low, input_high + + +@COMPRESSION_MODULES.register() +@QUANTIZATION_MODULES.register(FracBitsQuantizationMode.ASYMMETRIC) +class FracBitsAsymmetricQuantizer(AsymmetricQuantizer): + def __init__(self, qspec: PTQuantizerSpec): + super().__init__(qspec) + self._min_num_bits = int(0.5 * qspec.num_bits) + self._max_num_bits = int(1.5 * qspec.num_bits) + self._num_bits = CompressionParameter(torch.FloatTensor([qspec.num_bits]), requires_grad=True, + compression_lr_multiplier=qspec.compression_lr_multiplier) + + @property + def frac_num_bits(self): + return torch.clamp(self._num_bits, self._min_num_bits, self._max_num_bits) + + @property + def num_bits(self) -> int: + if self._num_bits.dtype == torch.int32: + return super().num_bits + + with no_jit_trace(): + return self.frac_num_bits.round().int().item() + + @num_bits.setter + def num_bits(self, num_bits: int): + if num_bits < self._min_num_bits or num_bits > self._max_num_bits: + raise RuntimeError( + f"{num_bits} should be in [{self._min_num_bits}, {self._max_num_bits}]") + self._num_bits.fill_(num_bits) + + @property + def is_num_bits_frozen(self) -> bool: + return not self._num_bits.requires_grad + + def set_min_max_num_bits(self, min_num_bits: int, max_num_bits: int): + if min_num_bits >= max_num_bits: + raise ValueError( + f"min_num_bits({min_num_bits}) >= max_num_bits({max_num_bits})") + self._min_num_bits = min_num_bits + self._max_num_bits = max_num_bits + + def unfreeze_num_bits(self) -> None: + self._num_bits.requires_grad_(True) + + def freeze_num_bits(self) -> None: + self._num_bits.requires_grad_(False) + super().set_level_ranges() + + def enable_gradients(self): + super().enable_gradients() + self.unfreeze_num_bits() + + def disable_gradients(self): + super().disable_gradients() + self.freeze_num_bits() + + def _quantize_with_n_bits(self, x, num_bits, execute_traced_op_as_identity: bool = False): + scaled_num_bits = 1 if self._half_range else 0 + + level_low, level_high, levels = self.calculate_level_ranges( + num_bits - scaled_num_bits) + + return asymmetric_quantize(x, levels, level_low, level_high, self.input_low, self.input_range, self.eps, + skip=execute_traced_op_as_identity) + + def quantize(self, x, execute_traced_op_as_identity: bool = False): + if self.is_num_bits_frozen: + return super().quantize(x, execute_traced_op_as_identity) + + fl_num_bits = self.frac_num_bits.floor().int().item() + ce_num_bits = fl_num_bits + 1 + + fl_q = self._quantize_with_n_bits( + x, fl_num_bits, execute_traced_op_as_identity) + ce_q = self._quantize_with_n_bits( + x, ce_num_bits, execute_traced_op_as_identity) + + return (self.frac_num_bits - fl_num_bits) * ce_q + (ce_num_bits - self.frac_num_bits) * fl_q + + def get_trainable_params(self) -> Dict[str, torch.Tensor]: + return {self.INPUT_LOW_PARAM_NAME: self.input_low.detach(), + self.INPUT_RANGE_PARAM_NAME: self.input_range.detach(), + "num_bits": self.frac_num_bits.detach()} + + def _prepare_export_quantization(self, x: torch.Tensor): + self.freeze_num_bits() + return super()._prepare_export_quantization(x) + + @torch.no_grad() + def get_input_range(self): + self.set_level_ranges() + input_low, input_high = self._get_input_low_input_high(self.input_range, + self.input_low, + self.levels, + self.eps) + return input_low, input_high
diff --git a/nncf/experimental/torch/fracbits/scheduler.py 
b/nncf/experimental/torch/fracbits/scheduler.py new file mode 100644 index 00000000000..b9e8a821c7d --- /dev/null +++ b/nncf/experimental/torch/fracbits/scheduler.py @@ -0,0 +1,34 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Callable + +from nncf.common.schedulers import BaseCompressionScheduler +from nncf.common.utils.logger import logger as nncf_logger +from nncf.experimental.torch.fracbits.params import FracBitsSchedulerParams + + +class FracBitsQuantizationScheduler(BaseCompressionScheduler): + def __init__(self, freeze_callback: Callable, params: FracBitsSchedulerParams): + super().__init__() + self._freeze_epoch = params.freeze_epoch + self._freeze_callback = freeze_callback + + if self._freeze_epoch < 0: + nncf_logger.warning(f"freeze_epoch={self._freeze_epoch} is less than 0. Fractional bit widths will not be frozen.") + + def epoch_step(self, next_epoch=None): + super().epoch_step(next_epoch) + if self._current_epoch == self._freeze_epoch: + nncf_logger.info(f"Current epoch is {self._current_epoch}. Freezing fractional bit widths.") + self._freeze_callback()
diff --git a/nncf/experimental/torch/fracbits/statistics.py b/nncf/experimental/torch/fracbits/statistics.py new file mode 100644 index 00000000000..a8e1951571d --- /dev/null +++ b/nncf/experimental/torch/fracbits/statistics.py @@ -0,0 +1,35 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from numbers import Number +from typing import Dict +from nncf.api.statistics import Statistics + +from nncf.common.utils.tensorboard import convert_to_dict + + +class FracBitsStatistics(Statistics): + def __init__(self, states: Dict[str, Number]) -> None: + super().__init__() + self.data = states + + def to_str(self) -> str: + return str(self.data) + + +@convert_to_dict.register(FracBitsStatistics) +def _convert_to_dict(stats: FracBitsStatistics, algorithm_name: str): + tensorboard_stats = { + algorithm_name + "/" + k: v for k, v in stats.data.items() + } + return tensorboard_stats diff --git a/nncf/torch/quantization/algo.py b/nncf/torch/quantization/algo.py index 08002b0b5a3..045d8217e98 100644 --- a/nncf/torch/quantization/algo.py +++ b/nncf/torch/quantization/algo.py @@ -539,7 +539,7 @@ def _parse_init_params(self): self._algo_config.get('initializer', {})) def _parse_range_init_params(self) -> Optional[PTRangeInitParams]: - range_init_params = extract_range_init_params(self.config) + range_init_params = extract_range_init_params(self.config, algorithm_name=self.name) return PTRangeInitParams(**range_init_params) if range_init_params is not None else None def _parse_precision_init_params(self, initializer_config: Dict) -> Tuple[str, BasePrecisionInitParams]: @@ -1236,7 +1236,7 @@ def initialize(self, model: NNCFNetwork) -> None: bn_adapt_params = self._parse_bn_adapt_params() if bn_adapt_params is not None: bn_adaptation = BatchnormAdaptationAlgorithm( - **extract_bn_adaptation_init_params(self.config, 'quantization')) + **extract_bn_adaptation_init_params(self.config, algo_name=self.name)) bn_adaptation.run(model) @@ -1436,7 +1436,7 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics: stats = collector.collect() nncf_stats = NNCFStatistics() - nncf_stats.register('quantization', stats) + nncf_stats.register(self.name, stats) return nncf_stats diff --git a/tests/torch/experimental/fracbits/conftest.py b/tests/torch/experimental/fracbits/conftest.py new file mode 100644 index 00000000000..e49897dd743 --- /dev/null +++ b/tests/torch/experimental/fracbits/conftest.py @@ -0,0 +1,102 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +import pytest +import torch +from torch import nn + +from nncf.torch.quantization.layers import PTQuantizerSpec +from tests.torch.helpers import BasicConvTestModel, register_bn_adaptation_init_args +from tests.torch.quantization.test_quantization_helpers import get_empty_config + + +def set_manual_seed(): + torch.manual_seed(3003) + + +@pytest.fixture(scope="function", name="linear_problem") +def fxt_linear_problem(num_bits: int = 4, sigma: float = 0.2): + set_manual_seed() + + levels = 2 ** num_bits + w = 1 / levels * (torch.randint(0, levels, size=[100, 10]) - levels // 2) + x = torch.randn([1000, 10]) + y = w.mm(x.t()) + y += sigma * torch.randn_like(y) + + return w, x, y, num_bits, sigma + + +@pytest.fixture() +def qspec(request): + return PTQuantizerSpec(num_bits=8, + mode=request.param, + signedness_to_force=None, + scale_shape=(1, 1), + narrow_range=False, + half_range=False, + logarithm_scale=False) + + +@pytest.fixture(name="config") +def fxt_config(model_size: int = 4): + new_config = get_empty_config(model_size) + + new_config["compression"] = { + "algorithm": "fracbits_quantization", + "initializer": { + "range": { + "num_init_samples": 0 + } + }, + "freeze_epoch": -1, + "loss": { + "type": "model_size", + "compression_rate": 1.5, + "criteria": "L1", + "flip_loss": False, + "alpha": 1.0 + } + } + register_bn_adaptation_init_args(new_config) + return new_config + + +@pytest.fixture(name="conv_model") +def fxt_conv_model(): + set_manual_seed() + return BasicConvTestModel() + + +@pytest.fixture() +def lp_with_config_and_model(linear_problem, config): + w, x, y, num_bits, sigma = linear_problem + x = x.unsqueeze(0) + y = y.unsqueeze(0).transpose(-2, -1) + + model = nn.Linear(in_features=w.shape[1], out_features=w.shape[0], bias=False) + with torch.no_grad(): + model.weight.copy_(w) + + config["input_info"] = [{"sample_size": list(x.shape)}] + return model, x, y, num_bits, sigma, nn.MSELoss(), config + + +@pytest.fixture +def conv_model_with_input_output(conv_model): + with torch.no_grad(): + x = torch.randn([1, 1, 4, 4]) + y = conv_model(x) + y += 0.1 * torch.randn_like(y) + + return conv_model, x, y diff --git a/tests/torch/experimental/fracbits/test_builder.py b/tests/torch/experimental/fracbits/test_builder.py new file mode 100644 index 00000000000..1047008e5df --- /dev/null +++ b/tests/torch/experimental/fracbits/test_builder.py @@ -0,0 +1,122 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from copy import deepcopy +import torch +from nncf.common.statistics import NNCFStatistics +from nncf.experimental.torch.fracbits.loss import ModelSizeCompressionLoss + +from nncf.torch.compression_method_api import PTCompressionLoss +from nncf.torch.dynamic_graph.scope import Scope, ScopeElement +from nncf.torch.model_creation import create_compression_algorithm_builder +from nncf.torch.module_operations import UpdateInputs, UpdateWeight +from nncf.torch.utils import get_all_modules_by_type + +from tests.torch.helpers import create_compressed_model_and_algo_for_test +from nncf.experimental.torch.fracbits.builder import FracBitsQuantizationBuilder +from nncf.experimental.torch.fracbits.quantizer import FracBitsSymmetricQuantizer + +#pylint: disable=protected-access + + +def test_create_builder(config): + builder = create_compression_algorithm_builder(config) + assert isinstance(builder, FracBitsQuantizationBuilder) + + +def test_can_load_quant_algo_with_defaults(config, conv_model): + quant_model, _ = create_compressed_model_and_algo_for_test(deepcopy(conv_model), config) + + model_conv = get_all_modules_by_type(conv_model, 'Conv2d') + quant_model_conv = get_all_modules_by_type( + quant_model.get_nncf_wrapped_model(), 'NNCFConv2d') + assert len(model_conv) == len(quant_model_conv) + + for module_scope, _ in model_conv.items(): + true_quant_scope: Scope = deepcopy(module_scope) + true_quant_scope.pop() + true_quant_scope.push(ScopeElement('NNCFConv2d', 'conv')) + assert true_quant_scope in quant_model_conv.keys() + + store = [] + for op in quant_model_conv[true_quant_scope].pre_ops.values(): + if isinstance(op, (UpdateInputs, UpdateWeight)) and isinstance(op.operand, FracBitsSymmetricQuantizer): + assert op.__class__.__name__ not in store + store.append(op.__class__.__name__) + assert UpdateWeight.__name__ in store + + +def test_quant_loss(config, conv_model): + _, compression_ctrl = create_compressed_model_and_algo_for_test(conv_model, config) + + loss = compression_ctrl.loss + assert isinstance(loss, PTCompressionLoss) + + loss_value = compression_ctrl.loss.calculate() + assert isinstance(loss_value, torch.Tensor) + assert loss_value.grad_fn is not None + + # Check whether bit_width gradient is not None + loss_value.backward() + for qinfo in compression_ctrl.weight_quantizers.values(): + q = qinfo.quantizer_module_ref + assert q._num_bits.grad.data is not None + + +def test_quant_loss_params(config, conv_model): + _, compression_ctrl = create_compressed_model_and_algo_for_test(conv_model, config) + + loss: ModelSizeCompressionLoss = compression_ctrl.loss + assert isinstance(loss, ModelSizeCompressionLoss) + + loss_config = config["compression"]["loss"] + assert isinstance(loss._criteria, torch.nn.L1Loss) + assert loss._alpha == loss_config["alpha"] + assert loss._flip_loss == loss_config["flip_loss"] + assert loss._compression_rate.item() == loss_config["compression_rate"] + + +def test_e2e_quant_loss(config, conv_model_with_input_output): + conv_model, x, y = conv_model_with_input_output + criterion = torch.nn.MSELoss() + + quant_model, compression_ctrl = create_compressed_model_and_algo_for_test(conv_model, config) + + optimizer = torch.optim.SGD(quant_model.parameters(), lr=1e-1) + + for i in range(500): + optimizer.zero_grad() + target_loss = criterion(quant_model(x), y) + comp_loss = compression_ctrl.loss.calculate() + + loss = target_loss + comp_loss + loss.backward() + optimizer.step() + + if i == 300: + compression_ctrl.freeze_bit_widths() + + target_comp_rate = 
config["compression"]["loss"]["compression_rate"] + for qinfo in compression_ctrl.weight_quantizers.values(): + q = qinfo.quantizer_module_ref + assert q.num_bits <= int(8 / target_comp_rate) + + +def test_statistics(config, conv_model): + _, ctrl = create_compressed_model_and_algo_for_test(conv_model, config) + + stats: NNCFStatistics = ctrl.statistics() + assert stats.quantization is not None + + dict_stats = dict(stats) + assert dict_stats[ctrl.name] is not None diff --git a/tests/torch/experimental/fracbits/test_quantizer.py b/tests/torch/experimental/fracbits/test_quantizer.py new file mode 100644 index 00000000000..cd6df5f636e --- /dev/null +++ b/tests/torch/experimental/fracbits/test_quantizer.py @@ -0,0 +1,86 @@ +""" + Copyright (c) 2022 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import pytest +import torch +from torch import nn + +from nncf.common.utils.logger import logger as nncf_logger +from nncf.experimental.torch.fracbits.quantizer import ( + FracBitsAsymmetricQuantizer, FracBitsSymmetricQuantizer, FracBitsQuantizationMode) + + +@pytest.mark.parametrize("add_bitwidth_loss", [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("qspec", + [FracBitsQuantizationMode.ASYMMETRIC, FracBitsQuantizationMode.SYMMETRIC], indirect=["qspec"]) +def test_quantization(linear_problem, qspec, device, add_bitwidth_loss): + """ + Test quantization for the simple linear problem. + The weight is filled with the random integer in range with [-2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1], + then scaled with 1 / (bit-width). Thus, it will finally be in [-0.5, 0.5]. + We initiate the quantizers input_low and input_high smaller than [-0.5, 0.5] by multiplying 0.1 to both limits. + Let SGD optimizer to learn quantizer parameters with a MSE loss for the linear model. + Check whether input_low and input_high is expanded to [-0.5, 0.5] to compensate quantization errors, + and the MSE loss is also minimized. If we add target bit_bidth loss, + we have to check whether our quantizer's learnable bit_width also goes to the target bit_width. 
+ """ + w, x, y, bit_width, sigma = linear_problem + + w, x, y = w.to(device=device), x.to(device=device), y.to(device=device) + + quant = FracBitsAsymmetricQuantizer( + qspec) if qspec.mode == FracBitsQuantizationMode.ASYMMETRIC else FracBitsSymmetricQuantizer(qspec) + + init_input_low = torch.FloatTensor([w.min() * 0.1]) + init_input_high = torch.FloatTensor([w.max() * 0.1]) + + quant.apply_minmax_init(init_input_low, init_input_high) + quant = quant.to(w.device) + criteria = nn.MSELoss() + + optim = torch.optim.SGD(quant.parameters(), lr=1e-1) + + for _ in range(100): + optim.zero_grad() + loss = criteria(y, quant(w).mm(x.t())) + + if add_bitwidth_loss: + loss += criteria(bit_width * + torch.ones_like(quant.frac_num_bits), quant.frac_num_bits) + + loss.backward() + optim.step() + + eps = 0.05 + ub_mse_loss = 1.1 * (sigma ** 2) + ub_left_q_w = w.min() + eps + lb_right_q_w = w.max() - eps + + with torch.no_grad(): + loss = criteria(y, quant(w).mm(x.t())).item() + nncf_logger.debug( + f"loss={loss:.3f} should be lower than ub_mse_loss={ub_mse_loss:.3f}.") + assert loss < ub_mse_loss + + left_q_w, right_q_w = quant.get_input_range() + left_q_w, right_q_w = left_q_w.item(), right_q_w.item() + + nncf_logger.debug(f"[left_q_w, right_q_w]^C [{left_q_w:.3f}, {right_q_w:.3f}]^C should be included in " + f"[ub_left_q_w, lb_right_q_w]^C = [{ub_left_q_w:.3f}, {lb_right_q_w:.3f}]^C.") + + assert left_q_w < ub_left_q_w + assert lb_right_q_w < right_q_w + + if add_bitwidth_loss: + assert quant.num_bits == bit_width