Description
System Info
Peft: https://github.com/BenjaminBossan/peft.git@d399d19706caa63dcaebd566c8017f3d59d33246#egg=peft
(still using the version I tested for @BenjaminBossan; unsure whether this new bug has been fixed in a more recent version)
Who can help?
Reproduction
Context: I was testing quantization and got an error because the modules_to_save had requires_grad=True, even though I had put my model in inference mode and passed inference_mode=False when calling set_adapter.
The adapters must contain some modules_to_save to reproduce the bug.
import bitsandbytes as bnb
import torch
import torch.nn as nn


class SafeInt8Linear(bnb.nn.Linear8bitLt):
    """
    Wraps Linear8bitLt to fix data type mismatches.
    Ensures input/output are FP16, which prevents Flash Attention crashes.
    """
    def forward(self, x):
        # 1. Force input to FP16 (optimizes BnB kernel speed)
        if x.dtype != torch.float16:
            x = x.to(torch.float16)
        # 2. Run INT8 matmul
        output = super().forward(x)
        # 3. Force output to FP16 (crucial for Flash Attention)
        if output.dtype != torch.float16:
            output = output.to(torch.float16)
        return output


def convert_to_int8(model):
    """Convert all Linear layers with in/out features >= 256 to the INT8 safe wrapper."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Skip very small layers (overhead outweighs speedup)
            if module.in_features < 256 or module.out_features < 256:
                continue
            # Create the safe INT8 version
            int8_layer = SafeInt8Linear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                has_fp16_weights=False,
                threshold=6.0,
            )
            # Copy weights (BnB requires CPU tensors for initialization)
            int8_layer.weight = bnb.nn.Int8Params(
                module.weight.data.cpu(),
                requires_grad=False,
                has_fp16_weights=False,
            )
            if module.bias is not None:
                int8_layer.bias = nn.Parameter(module.bias.data)
            # Replace in model
            parent = model
            name_list = name.split('.')
            for n in name_list[:-1]:
                parent = getattr(parent, n)
            setattr(parent, name_list[-1], int8_layer)
    return model
from peft import PeftModel

model = PeftModel.from_pretrained(
    some_model,
    adapter_path,
    adapter_name="my_adapter1",
    is_trainable=False,
)
model.load_adapter(adapter_path2, is_trainable=False, adapter_name="my_adapter2")

model = convert_to_int8(model)
model.to(device)
model.eval()

with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
    model.set_adapter("my_adapter1", inference_mode=False)
    _ = model(batch1)
    model.set_adapter("my_adapter2", inference_mode=False)
    _ = model(batch2)
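A quick way to see the offending state is to list which modules_to_save parameters still require grad after set_adapter (a minimal sketch, not part of the run above; the "modules_to_save" substring relies on PEFT's usual parameter naming for ModulesToSaveWrapper):

# Sketch: list modules_to_save parameters that end up with requires_grad=True
# after set_adapter; a non-empty list shows the problem.
offending = [
    name
    for name, param in model.named_parameters()
    if "modules_to_save" in name and param.requires_grad
]
print(f"modules_to_save params with requires_grad=True: {offending}")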
Expected behavior
modules_to_save should not have requires_grad set to True in inference mode.
From the look of it, set_adapter still sets the modules_to_save requires_grad to True even if inference_mode=False.
Suggested fix (working on my end): add if not inference_mode: before module.enable_adapters().
It works for modules_to_save, where the code only sets requires_grad, but I am unsure whether other AuxiliaryTrainingWrapper modules could be impacted.
module.py
def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False):
    for module in model.modules():
        if isinstance(module, AuxiliaryTrainingWrapper):
            # only check the adapter_name if we actually encounter a AuxiliaryTrainingWrapper, otherwise we don't care
            adapter_name_to_set = module.check_set_adapter(adapter_name)
            # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
            # module
            if adapter_name_to_set in module._adapters:
                if not inference_mode:
                    module.enable_adapters(True)
                module.set_adapter(adapter_name_to_set, inference_mode=inference_mode)
            else:
                if not inference_mode:
                    module.enable_adapters(False)
                module.set_adapter([], inference_mode=inference_mode)
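To check the guard, here is a rough regression-test sketch for the inference_mode=True path (the only path the guard changes). The tiny model and config are made up for illustration, and it assumes set_adapter forwards inference_mode down to _set_adapter as shown above:

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(8, 8)
        self.lin1 = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin1(self.lin0(x))


# Hypothetical setup: lin0 gets a LoRA adapter, lin1 goes into modules_to_save.
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(TinyModel(), config, adapter_name="my_adapter1")

# Mimic an adapter loaded with is_trainable=False: nothing requires grad.
for param in model.parameters():
    param.requires_grad_(False)
model.eval()

# With the guard in place, re-selecting the adapter for inference should not
# flip requires_grad back on for the modules_to_save copy.
model.set_adapter("my_adapter1", inference_mode=True)
offending = [
    name
    for name, param in model.named_parameters()
    if "modules_to_save" in name and param.requires_grad
]
assert not offending, f"unexpected trainable modules_to_save params: {offending}"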