
Inference mode with Module_to_save LoRA #2928

@NguyenRichard

Description


System Info

Peft: https://github.com/BenjaminBossan/peft.git@d399d19706caa63dcaebd566c8017f3d59d33246#egg=peft
(still on the version I tested for @BenjaminBossan; unsure whether this new bug has already been fixed in a more recent version)

Who can help?

@BenjaminBossan

Reproduction

Context: I was testing quantization and got an error because the modules_to_save parameters had requires_grad=True, even though I had put my model in inference mode and passed inference_mode=False to set_adapter:

The adapters must contain some modules_to_save to reproduce the bug.

import bitsandbytes as bnb
import torch
import torch.nn as nn
from peft import PeftModel


class SafeInt8Linear(bnb.nn.Linear8bitLt):
    """
    Wraps Linear8bitLt to fix data type mismatches.
    Ensures input/output are FP16, which prevents Flash Attention crashes.
    """

    def forward(self, x):
        # 1. Force Input to FP16 (Optimizes BnB kernel speed)
        if x.dtype != torch.float16:
            x = x.to(torch.float16)

        # 2. Run INT8 Matmul
        output = super().forward(x)

        # 3. Force Output to FP16 (Crucial for Flash Attention)
        if output.dtype != torch.float16:
            output = output.to(torch.float16)

        return output


def convert_to_int8(model):
    """Convert ALL Linear layers > 256 dim to INT8 SafeWrapper."""

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Skip very small layers (overhead outweighs speedup)
            if module.in_features < 256 or module.out_features < 256:
                continue

            # Create the Safe INT8 version
            int8_layer = SafeInt8Linear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                has_fp16_weights=False,
                threshold=6.0
            )

            # Copy weights (BnB requires CPU tensors for initialization)
            int8_layer.weight = bnb.nn.Int8Params(
                module.weight.data.cpu(),
                requires_grad=False,
                has_fp16_weights=False
            )

            if module.bias is not None:
                int8_layer.bias = nn.Parameter(module.bias.data)

            # Replace in model
            parent = model
            name_list = name.split('.')
            for n in name_list[:-1]:
                parent = getattr(parent, n)

            setattr(parent, name_list[-1], int8_layer)

    return model


model = PeftModel.from_pretrained(
    some_model,
    adapter_path,
    adapter_name="my_adapter1",
    is_trainable=False,
)
model.load_adapter(adapter_path2, adapter_name="my_adapter2", is_trainable=False)
model = convert_to_int8(model)
model.to(device)
model.eval()


with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
    model.set_adapter("my_adapter1", inference_mode=False)
    _ = model(batch1)
    model.set_adapter("my_adapter2", inference_mode=False)
    _ = model(batch2)
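
For reference, this is roughly how I surface the symptom after the set_adapter calls above (the helper is my own sketch, and the ModulesToSaveWrapper import path is an assumption that may vary across peft versions):

from peft.utils import ModulesToSaveWrapper  # import path assumed, may differ between peft versions


def report_modules_to_save_grads(model):
    """Print requires_grad for every parameter held by a modules_to_save wrapper."""
    for name, module in model.named_modules():
        if isinstance(module, ModulesToSaveWrapper):
            for param_name, param in module.named_parameters():
                print(f"{name}.{param_name}: requires_grad={param.requires_grad}")


# The adapter copies created for modules_to_save still report requires_grad=True here,
# even though the model is only ever used for inference.
report_modules_to_save_grads(model)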

Expected behavior

modules_to_save should not have requires_grad set to True when the model is in inference mode.

From the look of it, set_adapter still sets requires_grad=True on the modules_to_save parameters even with inference_mode=False.

Suggested fix (working on my end): add if not inference_mode: before the module.enable_adapters() calls.
It works for modules_to_save, where enable_adapters only toggles the gradients, but I am unsure whether any other auxiliary module could be impacted.
Patched _set_adapter (module.py):

def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False):
    for module in model.modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            # only check the adapter_name if we actually encounter an AuxiliaryTrainingWrapper, otherwise we don't care
            adapter_name_to_set = module.check_set_adapter(adapter_name)

            # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
            # module
            if adapter_name_to_set in module._adapters:
                if not inference_mode:
                    module.enable_adapters(True)
                module.set_adapter(adapter_name_to_set, inference_mode=inference_mode)
            else:
                if not inference_mode:
                    module.enable_adapters(False)
                module.set_adapter([], inference_mode=inference_mode)
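
With the patched version above, a quick sanity check along these lines is what I use to confirm that requesting inference mode leaves the gradients off (a sketch only; the "modules_to_save" substring filter is an assumption about how the extra copies are named, and I have only checked modules_to_save, not other auxiliary modules):

# Sanity-check sketch for the patched _set_adapter, reusing the names from the reproduction above.
model.set_adapter("my_adapter1", inference_mode=True)

offending = [
    name
    for name, param in model.named_parameters()
    if "modules_to_save" in name and param.requires_grad
]
# With the `if not inference_mode:` guard in place, this list stays empty.
assert not offending, f"still trainable in inference mode: {offending}"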
