125 commits
190f9d5
init
anzr299 Sep 22, 2025
c52fcca
fixes
anzr299 Sep 22, 2025
4e56cb5
add message for unsupported external quantizers
anzr299 Sep 22, 2025
9651ceb
add algorithm
anzr299 Sep 22, 2025
14daeb5
import openvino quantizer from nncf instead of executorch
anzr299 Sep 22, 2025
3746815
Add observers and openvino quantizer to nncf
anzr299 Sep 22, 2025
0815dc5
fix
anzr299 Sep 22, 2025
1b8d940
minor fix
anzr299 Sep 22, 2025
7d35374
fix
anzr299 Sep 22, 2025
427ebc2
fix some more bugs; observers was importing from torchao. causing mis…
anzr299 Sep 22, 2025
24dbfb6
add compress pt2e to init
anzr299 Sep 22, 2025
4bb8c1a
fix quantizer init file. Remove extra code.
anzr299 Sep 22, 2025
8902842
small fix for the big problem:)
anzr299 Sep 23, 2025
3842538
fix quantizer preset definition
anzr299 Sep 23, 2025
2e70c2e
fix openvino quantizer for ptq. call _algo instead of legacy _min_max…
anzr299 Sep 23, 2025
b1c9aad
fix quantizer defaults
anzr299 Sep 23, 2025
33fe01c
microfix
anzr299 Sep 23, 2025
d8e1006
precommit fix
anzr299 Sep 23, 2025
88a8472
revert openvino quantizer to old
anzr299 Sep 23, 2025
7a8e51a
create ovquantizer in executorch dir
anzr299 Sep 23, 2025
fed5052
update executorch quantizer location.
anzr299 Sep 23, 2025
2866473
check if openvino quantizer has weight compression in openvino adapter
anzr299 Sep 23, 2025
7171d56
review comments
anzr299 Sep 24, 2025
3e3b067
revert ignored scope changes; make sensitivity metric None to check i…
anzr299 Sep 24, 2025
5b7b210
precommit fix
anzr299 Sep 24, 2025
71a479f
pre commit format
anzr299 Sep 24, 2025
b24a59c
rename executorch quantizer to test_quantizer
anzr299 Sep 24, 2025
d12225a
fix last precommit
anzr299 Sep 24, 2025
9870ee2
remove unused mypy ignore
anzr299 Sep 24, 2025
8015629
get the mode as struct
anzr299 Sep 24, 2025
0804218
fix algorithm
anzr299 Sep 24, 2025
1f1fda3
remove quantizer and observers from nncf. Instead import from executorch
anzr299 Sep 24, 2025
623ce46
rework wc algorithm so that get_weight_comrpession_params becomes mor…
anzr299 Oct 1, 2025
d14a6eb
fix bugs; use sensitivity metric instead of mixed precision algo
anzr299 Oct 1, 2025
e91b455
update algorithm with new reworking
anzr299 Oct 6, 2025
448bf84
changes
anzr299 Oct 6, 2025
8e23572
review changes
anzr299 Oct 6, 2025
36ddf53
change WeightsCompressionPT2E to ExperimentalWeightsCompression
anzr299 Oct 7, 2025
07b730b
change ExperimentalWeightsCompression to WeightsCompression
anzr299 Oct 7, 2025
d5dd422
add comments
anzr299 Oct 7, 2025
076a76b
add typehints
anzr299 Oct 7, 2025
2ce9eec
add docstrings
anzr299 Oct 7, 2025
1bebf3e
add typehint for quantize pt2e
anzr299 Oct 7, 2025
ea81cfd
Merge branch 'openvinotoolkit:develop' into an/fx/compress_pt2e
anzr299 Oct 7, 2025
e82920f
return original develop branch changes
anzr299 Oct 7, 2025
82cc10b
update typehints and docs
anzr299 Oct 7, 2025
beae508
format
anzr299 Oct 7, 2025
8bd95df
update type hinting of openvino adapter
anzr299 Oct 7, 2025
aac9d3f
add test
anzr299 Oct 10, 2025
4278cfd
update reference graphs; use more samples for calibration dataset. Th…
anzr299 Oct 10, 2025
6fd5216
remove groupsize values as return statement from get_weight_compressi…
anzr299 Oct 10, 2025
118b611
update algorithm
anzr299 Oct 13, 2025
e9f3cd4
change WeightCompression to OriginalWeightCompression in experimental…
anzr299 Oct 13, 2025
a969e58
update docstrings as discussed offline
anzr299 Oct 13, 2025
71d0597
revert torchaoadapter code
anzr299 Oct 13, 2025
bf671ff
precommit fix
anzr299 Oct 14, 2025
5f1c2de
rename test_quantizer to test_quantizer_compression.py
anzr299 Oct 14, 2025
6f81879
review changes
anzr299 Oct 14, 2025
eb0ff16
review changes
anzr299 Oct 14, 2025
8afeb9d
precommit fix
anzr299 Oct 14, 2025
f491c8d
update quantizer test to include scales; remove sensitivity metric fro…
anzr299 Oct 15, 2025
b9f3eff
update test and references
anzr299 Oct 15, 2025
09dabf6
minor
anzr299 Oct 15, 2025
68316a5
add workflow for executorch test
anzr299 Oct 15, 2025
58b8992
update workflow and makefile
anzr299 Oct 15, 2025
e7bae1f
update executorch test requirements.
anzr299 Oct 15, 2025
4b0d8ea
Merge branch 'openvinotoolkit:develop' into an/fx/compress_pt2e
anzr299 Oct 15, 2025
d4da34f
fix precommit
anzr299 Oct 15, 2025
2b91658
override constraint in executorch workflow
anzr299 Oct 15, 2025
93c3f19
minor fix
anzr299 Oct 15, 2025
932b296
update workflow for fix
anzr299 Oct 15, 2025
67ab135
update workflow file
anzr299 Oct 15, 2025
a23acaf
install executorch after pytorch
anzr299 Oct 15, 2025
6462284
install torch nightly
anzr299 Oct 15, 2025
cf7e8d3
update requirements and revert workflow changes
anzr299 Oct 15, 2025
a07dc07
fix minor workflow file issue
anzr299 Oct 15, 2025
0506bca
install with no build isolation
anzr299 Oct 15, 2025
f8675ad
include executorch requirements
anzr299 Oct 15, 2025
52a7d5a
include openvino in requirements
anzr299 Oct 15, 2025
9e02948
fix
anzr299 Oct 15, 2025
a578fce
fix
anzr299 Oct 15, 2025
8ae6a80
update requirements
anzr299 Oct 15, 2025
c7210b8
add conftest and __init__
anzr299 Oct 15, 2025
2f8b296
use older pytorch commit
anzr299 Oct 15, 2025
75ccdcb
change torch versions to 2.10.0.dev20250922+cpu
anzr299 Oct 16, 2025
75cc255
install executorch directly from requirements txt
anzr299 Oct 16, 2025
3cdfe74
comments
anzr299 Oct 16, 2025
e4f9286
separate executorch installation
anzr299 Oct 16, 2025
009c587
precommit fix
anzr299 Oct 16, 2025
6e379c8
conftest precommit
anzr299 Oct 16, 2025
e45f796
update ref location for executorch
anzr299 Oct 16, 2025
f2ece8c
define ratio in compress_pt2e API and not Quantizer itself; Update test
anzr299 Oct 16, 2025
387d69c
Apply suggestion from @daniil-lyakhov
anzr299 Oct 16, 2025
4ace0df
Merge branch 'openvinotoolkit:develop' into an/fx/compress_pt2e
anzr299 Nov 5, 2025
00c8897
pre-commit fix
anzr299 Nov 5, 2025
dd34b9b
Apply suggestions from code review
anzr299 Nov 13, 2025
b9509bc
review changes
anzr299 Nov 13, 2025
d960d9a
add mypy; review changes
anzr299 Nov 13, 2025
758bd67
precommit fix; separate mixed precision algorithm application from th…
anzr299 Nov 13, 2025
193c404
add credit for transformers
anzr299 Nov 13, 2025
8807c10
all optional arguments are keyword-only
anzr299 Nov 13, 2025
d86a90a
review changes
anzr299 Nov 13, 2025
f2f01f2
review changes
anzr299 Nov 13, 2025
0dc7f64
executorch fix
anzr299 Nov 14, 2025
f6d3739
remove extra comments
anzr299 Nov 14, 2025
7ffa572
fix duplication of tests
anzr299 Nov 14, 2025
64803c3
avoid square complexity
anzr299 Nov 14, 2025
9d25bed
review changes
anzr299 Nov 14, 2025
859328e
change private methods to public which are used externally
anzr299 Nov 14, 2025
33a4b77
review changes
anzr299 Nov 14, 2025
ffd601d
review changes
anzr299 Nov 14, 2025
8484f1a
remove extra function
anzr299 Nov 14, 2025
5aadf1c
minor fix
anzr299 Nov 14, 2025
f93eed2
remove private var assignments from experimental WC algo init
anzr299 Nov 17, 2025
a5bb632
review changes
anzr299 Nov 17, 2025
5665bed
add description for MP and validation methods in algo
anzr299 Nov 17, 2025
ee86a20
review changes
anzr299 Nov 17, 2025
b31a1ab
update docstring
anzr299 Nov 17, 2025
c6557f6
fix error
anzr299 Nov 17, 2025
ee9a2de
review changes
anzr299 Nov 17, 2025
b1fcfa9
review changes
anzr299 Nov 17, 2025
6c56b91
remove extra kwarg
anzr299 Nov 17, 2025
e5ea21b
fix executorch test
anzr299 Nov 17, 2025
7bf3c78
review changes
anzr299 Nov 17, 2025
f2d9968
review changes
anzr299 Nov 17, 2025
42 changes: 42 additions & 0 deletions .github/workflows/call_precommit.yml
@@ -136,6 +136,48 @@ jobs:
      env:
        NUM_WORKERS: 4

  executorch:
    timeout-minutes: 40
    runs-on: ubuntu-latest-8-cores
    defaults:
      run:
        shell: bash
    env:
      DEBIAN_FRONTEND: noninteractive
    steps:
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get --assume-yes install gcc g++ build-essential ninja-build libgl1-mesa-dev libglib2.0-0
      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          lfs: true
      - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: ${{ inputs.python_version }}
      - name: Runner info
        continue-on-error: true
        run: |
          cat /etc/*release
          cat /proc/cpuinfo
      - name: Override constraints
        if: ${{ inputs.override_requirements != '' }}
        run: python .github/scripts/override_constraints.py "${{ inputs.override_requirements }}"
        shell: bash
      - name: Install NNCF and test requirements
        run: |
          pip install . -r tests/executorch/requirements.txt
          # Executorch
          # Editable install due to https://github.com/pytorch/executorch/issues/6475
          pip install --no-build-isolation -e git+https://github.com/anzr299/executorch.git@an/quantizer_nncf_pt2e_support#egg=executorch
      - name: Print installed modules
        run: pip list
      - name: Run PyTorch precommit test scope
        run: |
          make test-executorch
        env:
          NUM_WORKERS: 4

  pytorch-cuda:
    timeout-minutes: 40
    runs-on: aks-linux-4-cores-28gb-gpu-tesla-t4
3 changes: 3 additions & 0 deletions Makefile
@@ -141,6 +141,9 @@ test-torch-cpu:
test-torch-cuda:
	pytest ${COVERAGE_ARGS} tests/torch -ra -m "cuda and not weekly and not nightly and not models_hub and not legacy" --junitxml ${JUNITXML_PATH}

test-executorch:
	pytest ${COVERAGE_ARGS} tests/executorch --junitxml ${JUNITXML_PATH}

test-torch-nightly:
	pytest ${COVERAGE_ARGS} tests/torch -m "nightly or legacy" --junitxml ${JUNITXML_PATH} $(DATA_ARG)

@@ -0,0 +1,10 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,141 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional

import torch

from nncf import AdvancedCompressionParameters
from nncf import CompressionFormat
from nncf import CompressWeightsMode
from nncf import Dataset
from nncf import SensitivityMetric
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.logging import nncf_logger
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.quantization.quantizer import Quantizer
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression as OriginalWeightCompression


class WeightsCompression(Algorithm):
    """
    Post-training Weight Compression algorithm implementation.

    Compresses weights of Linear and Embedding layers to 8-bit integer or
    to 4-bit integer/float depending on mode, ratio and group size.
    """

    def __init__(
        self,
        quantizer: Quantizer,
        ratio: float,
        subset_size: int,
        awq: bool,
        scale_estimation: bool,
        gptq: bool,
        lora_correction: bool,
        sensitivity_metric: SensitivityMetric,
        compression_format: CompressionFormat,
        advanced_parameters: Optional[AdvancedCompressionParameters],
    ) -> None:
        """
        :param quantizer: Quantizer to use in the WeightCompression algorithm.
        :param ratio: The ratio between primary and backup precisions (e.g. 0.9 means 90% of the layers specified as
            `ratio_defining_params` by the quantizer are quantized to INT4 and the rest to the backup precision).
        :param subset_size: Number of data samples to calculate activation statistics used for assigning different
            quantization precision.
        :param awq: Whether to use the modified AWQ algorithm.
        :param scale_estimation: Whether to use scale estimation for 4-bit layers.
        :param gptq: Whether to use the GPTQ algorithm.
        :param lora_correction: Whether to use the LoRA Correction algorithm.
        :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to
            preserve the accuracy of the model, the more sensitive layers receive a higher precision.
        :param compression_format: Describes the format in which the model is saved after weight compression.
        :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline.
        """
        self._quantizer = quantizer
        wc_config = quantizer.get_weight_compression_config()

        mode = wc_config.get("mode", CompressWeightsMode.INT8_ASYM)

        self._algo = OriginalWeightCompression(
            mode=CompressWeightsMode(mode),
            ratio=ratio,
            group_size=wc_config.get("group_size", None),
            ignored_scope=None,
            all_layers=wc_config.get("all_layers", None),
            sensitivity_metric=sensitivity_metric,
            awq=awq,
            subset_size=subset_size,
            scale_estimation=scale_estimation,
            gptq=gptq,
            lora_correction=lora_correction,
            backup_mode=wc_config.get("backup_mode", None),
            compression_format=compression_format,
            advanced_parameters=advanced_parameters,
        )

    def available_backends(self) -> list[BackendType]:
        return [BackendType.TORCH_FX]

    def apply(
        self,
        model: torch.fx.GraphModule,
        graph: NNCFGraph,
        statistic_points: Optional[StatisticPointsContainer] = None,
        dataset: Optional[Dataset] = None,
    ) -> torch.fx.GraphModule:
        self._algo.set_backend_entity(model)

        all_weight_params, ratio_defining_params, skipped_weight_params = (
            self._quantizer.get_weight_compression_parameters(model, graph)
        )
        # Collect statistics for the weights compression
        statistics, statistic_points = self._algo.collect_statistics_and_statistic_points(
            model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
        )
        # Apply the mixed precision algorithm to the ratio-defining parameters
        self._algo.apply_mixed_precision(ratio_defining_params, model, graph, statistic_points)
        self._algo.validate_group_size(ratio_defining_params)

        # Print statistics
        nncf_logger.info(
            self._algo.get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
        )

        # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
        all_weight_params = [w_params for w_params in all_weight_params if w_params.compression_config is not None]
        return self._algo.apply_with_parameters(
            model,
            graph,
            dataset,
            statistics,
            all_weight_params,
        )

    def get_statistic_points(
        self,
        model: torch.fx.GraphModule,
        graph: NNCFGraph,
        nodes_and_port_ids: Iterable[tuple[NNCFNode, int]],
    ) -> StatisticPointsContainer:
        """
        Returns statistic points, for which StatisticsCollector should collect statistics.

        :param model: Model for statistics collection.
        :param graph: Model graph.
        :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected.
        :return: Statistic points, for which StatisticsCollector should collect statistics.
        """
        return self._algo.get_statistic_points(model, graph, nodes_and_port_ids)
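The order of operations in `apply` (assign precisions to the ratio-defining weights, then drop every entry whose `compression_config` is still `None`) can be illustrated in plain Python. This is only a minimal sketch: `WeightParams`, `assign_mixed_precision` and `plan_compression` are hypothetical stand-ins that do not exist in NNCF, and the greedy rule is a simplification assumed for illustration, not NNCF's actual mixed-precision criterion.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WeightParams:
    """Hypothetical stand-in for WeightCompressionParameters."""

    node_name: str
    num_weights: int
    sensitivity: float = 0.0
    compression_config: Optional[str] = None  # None: keep the original precision


def assign_mixed_precision(ratio_defining, ratio, primary="int4", backup="int8"):
    """Simplified greedy rule (an assumption, not NNCF's criterion).

    The least sensitive weights get the primary precision until `ratio` of the
    ratio-defining weight elements is covered; all remaining ones get the backup.
    """
    budget = ratio * sum(p.num_weights for p in ratio_defining)
    covered = 0
    within_budget = True
    for p in sorted(ratio_defining, key=lambda p: p.sensitivity):
        within_budget = within_budget and covered + p.num_weights <= budget
        if within_budget:
            p.compression_config = primary
            covered += p.num_weights
        else:
            p.compression_config = backup


def plan_compression(all_params, ratio_defining, ratio):
    """Assign precisions, then keep only the weights that will be compressed."""
    assign_mixed_precision(ratio_defining, ratio)
    return [p for p in all_params if p.compression_config is not None]


layers = [WeightParams(f"linear_{i}", 100, s) for i, s in enumerate([0.1, 0.9, 0.5, 0.3])]
embedding = WeightParams("embedding", 50, compression_config="int8")
norm = WeightParams("norm", 10)  # no config assigned, so it is filtered out
plan = plan_compression([embedding, *layers, norm], layers, ratio=0.5)
print([(p.node_name, p.compression_config) for p in plan])
```

In the PR itself the quantizer, not the algorithm, decides which weights land in each list; the algorithm only applies mixed precision to the ratio-defining subset and then filters.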
34 changes: 33 additions & 1 deletion src/nncf/experimental/quantization/quantizer.py
@@ -11,10 +11,11 @@

from abc import ABC
from abc import abstractmethod
from typing import TypeVar
from typing import Any, TypeVar

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters

TModel = TypeVar("TModel")

@@ -43,3 +44,34 @@ def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> Single
        :param nncf_graph: NNCFGraph instance.
        :return: SingleConfigQuantizerSetup for the given model.
        """

    @abstractmethod
    def get_weight_compression_parameters(
        self,
        model: TModel,
        nncf_graph: NNCFGraph,
    ) -> tuple[
        list[WeightCompressionParameters],
        list[WeightCompressionParameters],
        list[WeightCompressionParameters],
    ]:
        """
        Obtains the weight compression parameters from the quantizer, which can be used to determine
        the weights to compress, the weights to skip, and the weights to consider for mixed precision assignment.

        :param model: Backend-specific model.
        :param nncf_graph: NNCFGraph instance.
        :return: Tuple of (all_weight_params, ratio_defining_params, skipped_weight_params) where:
            1. all_weight_params: all compressible weight parameters in the model
            2. ratio_defining_params: subset of weights used for mixed precision assignment
            3. skipped_weight_params: weights that should be excluded from compression
        """

    @abstractmethod
    def get_weight_compression_config(self) -> dict[str, Any]:
        """
        Returns the weight compression configuration as a dictionary.

        :return: Dictionary containing compression configuration parameters obtained
            from the quantizer.
        """
1 change: 1 addition & 0 deletions src/nncf/experimental/torch/fx/__init__.py
@@ -9,5 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nncf.experimental.torch.fx.quantization.quantize_pt2e import compress_pt2e as compress_pt2e
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e as quantize_pt2e
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer as OpenVINOQuantizer
72 changes: 72 additions & 0 deletions src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -22,11 +22,14 @@
from torch.fx.passes.infra.pass_manager import PassManager

import nncf
from nncf import AdvancedCompressionParameters
from nncf import Dataset
from nncf import SensitivityMetric
from nncf.common.factory import NNCFGraphFactory
from nncf.common.logging import nncf_logger
from nncf.common.utils.api_marker import api
from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
from nncf.experimental.quantization.algorithms.weight_compression.algorithm import WeightsCompression
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
@@ -157,3 +160,72 @@ def _quant_node_constraint(n: torch.fx.Node) -> bool:
    related to quantization
    """
    return n.op == "call_function" and n.target in QUANTIZE_NODE_TARGETS


@api(canonical_alias="nncf.experimental.torch.fx.compress_pt2e")
def compress_pt2e(
    model: torch.fx.GraphModule,
    quantizer: Quantizer,
    *,
    dataset: Optional[nncf.Dataset] = None,
    awq: bool = False,
    scale_estimation: bool = False,
    gptq: bool = False,
    lora_correction: bool = False,
    subset_size: int = 128,
    ratio: float = 1.0,
    sensitivity_metric: Optional[SensitivityMetric] = None,
    advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> torch.fx.GraphModule:
    """
    Applies Weight Compression to the torch.fx.GraphModule model using the provided torch.ao quantizer.

    :param model: A torch.fx.GraphModule instance to be compressed.
    :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
        to convey the desired way of quantization.
    :param dataset: A representative dataset for the calibration process.
    :param awq: Whether to use the modified AWQ algorithm.
    :param scale_estimation: Whether to use scale estimation for 4-bit layers.
    :param gptq: Whether to use the GPTQ algorithm.
    :param lora_correction: Whether to use the LoRA Correction algorithm.
    :param subset_size: Number of data samples to calculate activation statistics used for assigning different
        quantization precision.
    :param ratio: The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
        and the rest to INT8_ASYM).
    :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to
        preserve the accuracy of the model, the more sensitive layers receive a higher precision.
    :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline.
    :return: The torch.fx.GraphModule with compressed weights.
    """
    if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_weight_compression_parameters"):
        quantizer = OpenVINOQuantizerAdapter(quantizer)
        compression_format = nncf.CompressionFormat.DQ
    else:
        # TODO: Support third-party quantizers here.
        msg = "Only OpenVINO Quantizer is supported currently."
        raise nncf.InternalError(msg)

    sensitivity_metric = (
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR if dataset is None else SensitivityMetric.MAX_ACTIVATION_VARIANCE)
        if sensitivity_metric is None
        else sensitivity_metric
    )

    quantization_algorithm = WeightsCompression(
        quantizer=quantizer,
        subset_size=subset_size,
        compression_format=compression_format,
        ratio=ratio,
        awq=awq,
        scale_estimation=scale_estimation,
        gptq=gptq,
        lora_correction=lora_correction,
        sensitivity_metric=sensitivity_metric,
        advanced_parameters=advanced_parameters,
    )

    # Here the model is annotated
    transformed_model = quantizer.transform_prior_quantization(model)
    nncf_graph = NNCFGraphFactory.create(transformed_model)
    quantized_model = quantization_algorithm.apply(transformed_model, nncf_graph, dataset=dataset)
    quantized_model = torch.fx.GraphModule(quantized_model, graph=quantized_model.graph)
    return quantized_model
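The nested conditional that picks the default sensitivity metric in `compress_pt2e` is easier to follow when written as early returns. The sketch below restates that rule in isolation; `Metric` and `resolve_metric` are hypothetical names introduced here, with `Metric` standing in for `nncf.SensitivityMetric` so the snippet runs without NNCF installed.

```python
from enum import Enum
from typing import Any, Optional


class Metric(Enum):
    """Hypothetical stand-in for nncf.SensitivityMetric (two members only)."""

    WEIGHT_QUANTIZATION_ERROR = "weight_quantization_error"
    MAX_ACTIVATION_VARIANCE = "max_activation_variance"


def resolve_metric(metric: Optional[Metric], dataset: Optional[Any]) -> Metric:
    """Mirror the defaulting rule: an explicit metric always wins."""
    if metric is not None:
        return metric
    if dataset is None:
        # Without calibration data only a data-free metric can be used.
        return Metric.WEIGHT_QUANTIZATION_ERROR
    return Metric.MAX_ACTIVATION_VARIANCE


print(resolve_metric(None, None))
print(resolve_metric(None, dataset=[1, 2, 3]))
print(resolve_metric(Metric.WEIGHT_QUANTIZATION_ERROR, dataset=[1, 2, 3]))
```

The rule is behaviorally the same as the conditional expression in the diff; only the layout differs.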
@@ -9,12 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch.fx

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.experimental.quantization.quantizer import Quantizer
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters


class OpenVINOQuantizerAdapter(Quantizer):
@@ -30,3 +33,17 @@ def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.

    def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        return self._quantizer.get_nncf_quantization_setup(model, nncf_graph)

    def get_weight_compression_parameters(
        self,
        model: torch.fx.GraphModule,
        nncf_graph: NNCFGraph,
    ) -> tuple[
        list[WeightCompressionParameters],
        list[WeightCompressionParameters],
        list[WeightCompressionParameters],
    ]:
        return self._quantizer.get_nncf_weight_compression_parameters(model, nncf_graph)

    def get_weight_compression_config(self) -> dict[str, Any]:
        return self._quantizer.weight_compression_configuration
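The adapter only delegates, and the configuration dictionary it returns is read by the `WeightsCompression` constructor with `.get` lookups and fallbacks. The sketch below shows that hand-off end to end in plain Python. All names in it (`Mode`, `FakeQuantizer`, `Adapter`, `resolve_config`) are hypothetical stand-ins; `Mode` imitates `nncf.CompressWeightsMode`, and its string values are an assumption made for illustration.

```python
from enum import Enum
from typing import Any


class Mode(Enum):
    """Hypothetical stand-in for nncf.CompressWeightsMode."""

    INT8_ASYM = "int8_asym"
    INT4_SYM = "int4_sym"


class FakeQuantizer:
    """Hypothetical quantizer exposing the two members the adapter relies on."""

    weight_compression_configuration = {"mode": "int4_sym", "group_size": 32}

    def get_nncf_weight_compression_parameters(self, model, graph):
        # (all_weight_params, ratio_defining_params, skipped_weight_params)
        return ["w0", "w1"], ["w1"], []


class Adapter:
    """Delegating wrapper in the spirit of OpenVINOQuantizerAdapter."""

    def __init__(self, quantizer: Any) -> None:
        self._quantizer = quantizer

    def get_weight_compression_parameters(self, model, graph):
        return self._quantizer.get_nncf_weight_compression_parameters(model, graph)

    def get_weight_compression_config(self) -> dict[str, Any]:
        return self._quantizer.weight_compression_configuration


def resolve_config(wc_config: dict[str, Any]) -> dict[str, Any]:
    """Apply the fallbacks: INT8_ASYM mode, everything else None."""
    # Enum lookup accepts either a member or its string value.
    mode = Mode(wc_config.get("mode", Mode.INT8_ASYM))
    return {
        "mode": mode,
        "group_size": wc_config.get("group_size"),
        "all_layers": wc_config.get("all_layers"),
        "backup_mode": wc_config.get("backup_mode"),
    }


adapter = Adapter(FakeQuantizer())
config = resolve_config(adapter.get_weight_compression_config())
all_params, ratio_defining, skipped = adapter.get_weight_compression_parameters(None, None)
print(config)
```

Because missing keys fall back to defaults, a quantizer that returns an empty dictionary still yields a valid configuration in this sketch, with `INT8_ASYM` as the mode.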