17 changes: 16 additions & 1 deletion xtuner/model/internvl_1_5_llava.py
@@ -5,6 +5,7 @@

from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.plora import add_plora
from .utils import (LoadWoInit, guess_load_checkpoint,
make_inputs_require_grad,
prepare_inputs_labels_for_multimodal)
@@ -30,6 +31,10 @@ def __init__(self, llm,
max_position_embeddings=None,
image_processor=None,
tokenizer=None,
use_plora=False,
plora_r=256,
plora_alpha=256,
plora_dropout=0.05,
template=None,
use_lldr=False, # LearningRateDecayOptimWrapperConstructor
merge_type='pixel_shuffle', # or pixel_shuffle
@@ -58,6 +63,13 @@ def __init__(self, llm,
self.llm.config.use_cache = False
dispatch_modules(self.llm)

self.use_plora = use_plora
if self.use_plora:
add_plora(self.llm,
lora_r=plora_r,
lora_alpha=plora_alpha,
lora_dropout=plora_dropout)

self.custom_mlp = custom_mlp
if custom_mlp is True:
self.mlp1 = nn.Sequential(
@@ -162,7 +174,10 @@ def _prepare_data_for_llm(self, data):
if 'pixel_values' in data:
new_image_feature = self.__preprocess_for_pixel_values(data)
data['pixel_values'] = new_image_feature
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
data = prepare_inputs_labels_for_multimodal(
llm=self.llm,
save_im_mask=self.use_plora,
**data)
return data

def __preprocess_for_pixel_values(self, data):
17 changes: 16 additions & 1 deletion xtuner/model/llava.py
@@ -12,6 +12,7 @@
from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .modules.plora import add_plora
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
@@ -44,6 +45,10 @@ def __init__(self,
max_position_embeddings=None,
image_processor=None,
tokenizer=None,
use_plora=False,
plora_r=256,
plora_alpha=256,
plora_dropout=0.05,
template=None,
use_lldr=False, # LearningRateDecayOptimWrapperConstructor
):
@@ -69,6 +74,13 @@ def __init__(self,
self.llm.config.use_cache = False
dispatch_modules(self.llm)

self.use_plora = use_plora
if self.use_plora:
add_plora(self.llm,
lora_r=plora_r,
lora_alpha=plora_alpha,
lora_dropout=plora_dropout)

assert int(token_merge_ratio ** 0.5) ** 2 == token_merge_ratio, \
'`token_merge_ratio` must be a square number.'
self.token_merge_ratio = int(token_merge_ratio)
@@ -355,7 +367,10 @@ def _prepare_data_for_llm(self, data):
pixel_values = self.projector(visual_outputs)

data['pixel_values'] = pixel_values
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
data = prepare_inputs_labels_for_multimodal(
llm=self.llm,
save_im_mask=self.use_plora,
**data)
return data

def forward(self, data, data_samples=None, mode='loss'):
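For context, a minimal sketch of how the new switches could be enabled from a config. Only the four PLoRA arguments come from this diff; the class name, import path, and the remaining fields are assumed placeholders for an existing LLaVA setup.

```python
# Hypothetical config excerpt: only use_plora / plora_r / plora_alpha /
# plora_dropout are taken from this diff; everything else is a placeholder.
from xtuner.model import LLaVAModel  # assumed export path

model = dict(
    type=LLaVAModel,
    llm=...,             # existing LLM sub-config, unchanged
    visual_encoder=...,  # existing vision tower sub-config, unchanged
    use_plora=True,      # wrap eligible nn.Linear layers of the LLM with PLoRA
    plora_r=256,         # LoRA rank
    plora_alpha=256,     # scaling numerator; effective scale is plora_alpha / plora_r
    plora_dropout=0.05,  # dropout applied to the PLoRA branch input
)
```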
55 changes: 55 additions & 0 deletions xtuner/model/modules/plora.py
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch.nn as nn
from mmengine import MessageHub, print_log
from mmengine.dist import get_rank


def add_plora_to_linear(module, lora_r=256, lora_alpha=256, lora_dropout=0.05):
device = module.weight.device
dtype = module.weight.dtype
Plora_A = nn.Linear(
module.in_features, lora_r, bias=False, device=device, dtype=dtype)
Plora_B = nn.Linear(
lora_r, module.out_features, bias=False, device=device, dtype=dtype)
nn.init.kaiming_uniform_(Plora_A.weight, a=math.sqrt(5))
nn.init.zeros_(Plora_B.weight)

lora_dropout = nn.Dropout(p=lora_dropout)
lora_scaling = lora_alpha / lora_r

module.add_module('Plora_A', Plora_A)
module.add_module('Plora_B', Plora_B)
module.add_module('lora_dropout', lora_dropout)
setattr(module, 'lora_scaling', lora_scaling)

def forward_plora(self, x):
res = self.forward_original(x)
rank = get_rank()
message_hub = MessageHub.get_instance('im_mask_info')
im_mask = message_hub.get_info(f'im_mask_{rank}')
if im_mask is not None and x.shape[1] == im_mask.shape[-1]:
part_x = x[im_mask]
res[im_mask] += self.Plora_B(self.Plora_A(self.lora_dropout(part_x))) * self.lora_scaling
return res

module.forward_original = module.forward
module.forward = forward_plora.__get__(module, nn.Linear)


def add_plora(model, lora_r=256, lora_alpha=256, lora_dropout=0.05):
for name, module in model.named_modules():
if (isinstance(module, nn.Linear) and 'Plora' not in name
and 'lm_head' not in name and 'output_layer' not in name):
print_log(f'Add PLoRA to {name}', 'current')
add_plora_to_linear(
module,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout)
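A minimal sketch of the new module in isolation, assuming this branch of xtuner is installed: `add_plora` patches every eligible `nn.Linear`, and the patched forward adds the low-rank branch only at positions flagged by the per-rank `im_mask` published under the `im_mask_info` MessageHub. The toy module and mask below are made up for illustration.

```python
import torch
import torch.nn as nn
from mmengine import MessageHub
from mmengine.dist import get_rank

from xtuner.model.modules.plora import add_plora  # module added in this diff

# Toy stand-in for an LLM block; the layer sizes and names are arbitrary.
toy = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
add_plora(toy, lora_r=8, lora_alpha=8, lora_dropout=0.0)

# Publish a per-rank image mask the same way the models do via
# prepare_inputs_labels_for_multimodal(save_im_mask=True):
# shape (batch, seq_len), True at image-token positions.
x = torch.randn(2, 6, 16)
im_mask = torch.zeros(2, 6, dtype=torch.bool)
im_mask[:, 2:4] = True
MessageHub.get_instance('im_mask_info').update_info(
    f'im_mask_{get_rank()}', im_mask)

# Each wrapped Linear runs its original forward and adds the
# Plora_B(Plora_A(x)) branch only at the masked positions
# (a zero contribution at init, since Plora_B starts at zeros).
y = toy(x)
print(y.shape)  # torch.Size([2, 6, 16])
```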
42 changes: 38 additions & 4 deletions xtuner/model/utils.py
@@ -3,7 +3,8 @@
from typing import List, Optional

import torch
from mmengine import print_log
from mmengine.dist import get_rank
from mmengine import print_log, MessageHub
from mmengine.utils.misc import get_object_from_string
from peft import PeftType
from torch import nn
@@ -129,6 +130,7 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name='default'):
# Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
def prepare_inputs_labels_for_multimodal(
llm: PreTrainedModel,
save_im_mask=False,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -137,14 +139,17 @@ def prepare_inputs_labels_for_multimodal(
pixel_values: Optional[torch.FloatTensor] = None,
**kwargs):
if pixel_values is None:
return {
rets = {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'past_key_values': past_key_values,
'inputs_embeds': None,
'labels': labels
}
if save_im_mask:
rets['im_mask'] = torch.zeros_like(input_ids, dtype=torch.bool)
return rets

_labels = labels
_position_ids = position_ids
@@ -170,6 +175,7 @@ def prepare_inputs_labels_for_multimodal(
]

new_inputs_embeds = []
new_im_mask = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
@@ -180,6 +186,11 @@ def prepare_inputs_labels_for_multimodal(
cur_inputs_embeds = torch.cat(
[cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
new_inputs_embeds.append(cur_inputs_embeds)
new_im_mask.append(
torch.full((cur_inputs_embeds.shape[0], ),
False,
dtype=torch.bool,
device=cur_inputs_embeds.device))
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
@@ -203,25 +214,38 @@ def prepare_inputs_labels_for_multimodal(
cur_inputs_embeds_no_im = torch.split(
cur_inputs_embeds, split_sizes, dim=0)
cur_new_inputs_embeds = []
cur_new_im_mask = []
cur_new_labels = []

for i in range(num_images + 1):
cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
cur_new_im_mask.append(
torch.full((cur_inputs_embeds_no_im[i].shape[0], ),
False,
dtype=torch.bool,
device=cur_inputs_embeds_no_im[i].device))
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_pixel_values = pixel_values[cur_image_idx]
cur_image_idx += 1
cur_new_inputs_embeds.append(cur_pixel_values)
cur_new_im_mask.append(
torch.full((cur_pixel_values.shape[0], ),
True,
dtype=torch.bool,
device=cur_pixel_values.device))
cur_new_labels.append(
torch.full((cur_pixel_values.shape[0], ),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype))

cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
cur_new_im_mask = torch.cat(cur_new_im_mask)
cur_new_labels = torch.cat(cur_new_labels)

new_inputs_embeds.append(cur_new_inputs_embeds)
new_im_mask.append(cur_new_im_mask)
new_labels.append(cur_new_labels)

# Combine them
@@ -239,9 +263,12 @@ def prepare_inputs_labels_for_multimodal(
position_ids = torch.zeros((batch_size, max_len),
dtype=position_ids.dtype,
device=position_ids.device)
im_mask = torch.zeros((batch_size, max_len),
dtype=torch.bool,
device=position_ids.device)

for i, (cur_new_embed,
cur_new_labels) in enumerate(zip(new_inputs_embeds, new_labels)):
for i, (cur_new_embed, cur_new_labels, cur_new_im_mask) in enumerate(zip(
new_inputs_embeds, new_labels, new_im_mask)):
cur_len = cur_new_embed.shape[0]
new_inputs_embeds_padded.append(
torch.cat((cur_new_embed,
@@ -257,6 +284,7 @@ def prepare_inputs_labels_for_multimodal(
cur_len,
dtype=position_ids.dtype,
device=position_ids.device)
im_mask[i, :cur_len] = cur_new_im_mask

new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

@@ -273,6 +301,11 @@ def prepare_inputs_labels_for_multimodal(
if _position_ids is None:
position_ids = None

if save_im_mask:
rank = get_rank()
message_hub = MessageHub.get_instance('im_mask_info')
message_hub.update_info(f'im_mask_{rank}', im_mask)

return {
'input_ids': None,
'position_ids': position_ids,
@@ -314,6 +347,7 @@ def guess_load_checkpoint(pth_model):
# from https://github.com/bfshi/scaling_on_scales

import math

import torch.nn.functional as F
from einops import rearrange

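To make the mask bookkeeping above easier to follow, here is a simplified, standalone sketch (with made-up sequence lengths) of what the additions to `prepare_inputs_labels_for_multimodal` compute: a False segment per run of text tokens, a True segment per spliced-in image, then right-padding to the batch maximum.

```python
import torch

# Made-up per-sample layout: (text tokens before the image, text tokens after).
text_lens = [(3, 5), (4, 2)]
num_image_tokens = 4  # visual tokens spliced into each sequence

per_sample_masks = []
for before, after in text_lens:
    per_sample_masks.append(torch.cat([
        torch.full((before,), False, dtype=torch.bool),
        torch.full((num_image_tokens,), True, dtype=torch.bool),
        torch.full((after,), False, dtype=torch.bool),
    ]))

# Right-pad with False to the batch max length, as the padded im_mask above does.
max_len = max(m.shape[0] for m in per_sample_masks)
im_mask = torch.zeros((len(per_sample_masks), max_len), dtype=torch.bool)
for i, m in enumerate(per_sample_masks):
    im_mask[i, :m.shape[0]] = m

print(im_mask.int())
# tensor([[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]])
```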