Skip to content

Commit 3bdeb0f

Browse files
committed
Removed CB regard kw args for functionin non CB models
Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
1 parent d7e4fa1 commit 3bdeb0f

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ class _QEffAutoModelForImageTextToTextDualQPC:
858858
def __init__(
859859
self,
860860
model: nn.Module,
861-
continuous_batching,
861+
continuous_batching: bool = False,
862862
**kwargs,
863863
):
864864
"""
@@ -982,8 +982,15 @@ def export(
982982
List[str]
983983
A list containing the paths to the generated ONNX graph files for both components.
984984
"""
985-
inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching)
986-
dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching)
985+
# TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed.
986+
try:
987+
inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching)
988+
dynamic_axes = self.model.get_onnx_dynamic_axes(
989+
kv_offload=True, continuous_batching=self.continuous_batching
990+
)
991+
except TypeError:
992+
inputs = self.model.get_dummy_inputs(kv_offload=True)
993+
dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
987994
output_names = self.model.get_output_names(kv_offload=True)
988995

989996
self.vision_model.export(
@@ -1124,6 +1131,11 @@ def compile(
11241131
):
11251132
self.export()
11261133

1134+
# TODO this hould be removed once the continous batching is supported for all the models.
1135+
compiler_options.pop("continuous_batching", None)
1136+
compiler_options.pop("kv_cache_batch_size", None)
1137+
compiler_options.pop("full_batch_size", None)
1138+
11271139
if not skip_vision:
11281140
self.vision_model._compile(
11291141
compile_dir=compile_dir,

0 commit comments

Comments
 (0)