From d8651801d9fa4ebe2337bde325245907106f1e93 Mon Sep 17 00:00:00 2001
From: linzhihao
Date: Fri, 31 May 2024 15:11:39 +0800
Subject: [PATCH] support plora for llava and internvl 1.5

---
 xtuner/model/internvl_1_5_llava.py | 17 ++++++++-
 xtuner/model/llava.py              | 17 ++++++++-
 xtuner/model/modules/plora.py      | 55 ++++++++++++++++++++++++++++++
 xtuner/model/utils.py              | 42 ++++++++++++++++++++---
 4 files changed, 125 insertions(+), 6 deletions(-)
 create mode 100644 xtuner/model/modules/plora.py

diff --git a/xtuner/model/internvl_1_5_llava.py b/xtuner/model/internvl_1_5_llava.py
index 4cc5a1b15..436f5cdb9 100644
--- a/xtuner/model/internvl_1_5_llava.py
+++ b/xtuner/model/internvl_1_5_llava.py
@@ -5,6 +5,7 @@
 
 from xtuner.registry import BUILDER
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
+from .modules.plora import add_plora
 from .utils import (LoadWoInit, guess_load_checkpoint,
                     make_inputs_require_grad,
                     prepare_inputs_labels_for_multimodal)
@@ -30,6 +31,10 @@ def __init__(self, llm,
                        max_position_embeddings=None,
                        image_processor=None,
                        tokenizer=None,
+                       use_plora=False,
+                       plora_r=256,
+                       plora_alpha=256,
+                       plora_dropout=0.05,
                        template=None,
                        use_lldr=False,  # LearningRateDecayOptimWrapperConstructor
                        merge_type='pixel_shuffle',  # or pixel_shuffle
@@ -58,6 +63,13 @@ def __init__(self, llm,
         self.llm.config.use_cache = False
         dispatch_modules(self.llm)
 
+        self.use_plora = use_plora
+        if self.use_plora:
+            add_plora(self.llm,
+                      lora_r=plora_r,
+                      lora_alpha=plora_alpha,
+                      lora_dropout=plora_dropout)
+
         self.custom_mlp = custom_mlp
         if custom_mlp is True:
             self.mlp1 = nn.Sequential(
@@ -162,7 +174,10 @@ def _prepare_data_for_llm(self, data):
         if 'pixel_values' in data:
             new_image_feature = self.__preprocess_for_pixel_values(data)
             data['pixel_values'] = new_image_feature
-            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
+            data = prepare_inputs_labels_for_multimodal(
+                llm=self.llm,
+                save_im_mask=self.use_plora,
+                **data)
         return data
 
     def __preprocess_for_pixel_values(self, data):
diff --git a/xtuner/model/llava.py b/xtuner/model/llava.py
index 1a4003dc3..203db4e97 100644
--- a/xtuner/model/llava.py
+++ b/xtuner/model/llava.py
@@ -12,6 +12,7 @@
 from xtuner.registry import BUILDER
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
 from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
+from .modules.plora import add_plora
 from .utils import (LoadWoInit, find_all_linear_names,
                     get_peft_model_state_dict, guess_load_checkpoint,
                     make_inputs_require_grad,
@@ -44,6 +45,10 @@ def __init__(self,
                  max_position_embeddings=None,
                  image_processor=None,
                  tokenizer=None,
+                 use_plora=False,
+                 plora_r=256,
+                 plora_alpha=256,
+                 plora_dropout=0.05,
                  template=None,
                  use_lldr=False,  # LearningRateDecayOptimWrapperConstructor
                  ):
@@ -69,6 +74,13 @@ def __init__(self,
         self.llm.config.use_cache = False
         dispatch_modules(self.llm)
 
+        self.use_plora = use_plora
+        if self.use_plora:
+            add_plora(self.llm,
+                      lora_r=plora_r,
+                      lora_alpha=plora_alpha,
+                      lora_dropout=plora_dropout)
+
         assert int(token_merge_ratio ** 0.5) ** 2 == token_merge_ratio, \
             '`token_merge_ratio` must be a square number.'
         self.token_merge_ratio = int(token_merge_ratio)
@@ -355,7 +367,10 @@ def _prepare_data_for_llm(self, data):
             pixel_values = self.projector(visual_outputs)
 
         data['pixel_values'] = pixel_values
-        data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
+        data = prepare_inputs_labels_for_multimodal(
+            llm=self.llm,
+            save_im_mask=self.use_plora,
+            **data)
         return data
 
     def forward(self, data, data_samples=None, mode='loss'):
diff --git a/xtuner/model/modules/plora.py b/xtuner/model/modules/plora.py
new file mode 100644
index 000000000..60deff493
--- /dev/null
+++ b/xtuner/model/modules/plora.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch.nn as nn
+from mmengine import MessageHub, print_log
+from mmengine.dist import get_rank
+
+
+def add_plora_to_linear(module, lora_r=256, lora_alpha=256, lora_dropout=0.05):
+    device = module.weight.device
+    dtype = module.weight.dtype
+    Plora_A = nn.Linear(
+        module.in_features, lora_r, bias=False, device=device, dtype=dtype)
+    Plora_B = nn.Linear(
+        lora_r, module.out_features, bias=False, device=device, dtype=dtype)
+    nn.init.kaiming_uniform_(Plora_A.weight, a=math.sqrt(5))
+    nn.init.zeros_(Plora_B.weight)
+
+    lora_dropout = nn.Dropout(p=lora_dropout)
+    lora_scaling = lora_alpha / lora_r
+
+    module.add_module('Plora_A', Plora_A)
+    module.add_module('Plora_B', Plora_B)
+    module.add_module('lora_dropout', lora_dropout)
+    setattr(module, 'lora_scaling', lora_scaling)
+
+    def forward_plora(self, x):
+        res = self.forward_original(x)
+        rank = get_rank()
+        message_hub = MessageHub.get_instance('im_mask_info')
+        im_mask = message_hub.get_info(f'im_mask_{rank}')
+        # if rank in [0, 1]:
+        #     print('*****', flush=True)
+        #     print(rank, flush=True)
+        #     print(im_mask.sum(-1), flush=True)
+        #     print('*****', flush=True)
+        if im_mask is not None and x.shape[1] == im_mask.shape[-1]:
+            part_x = x[im_mask]
+            res[im_mask] += self.Plora_B(self.Plora_A(self.lora_dropout(part_x))) * self.lora_scaling
+        return res
+
+    module.forward_original = module.forward
+    module.forward = forward_plora.__get__(module, nn.Linear)
+
+
+def add_plora(model, lora_r=256, lora_alpha=256, lora_dropout=0.05):
+    for name, module in model.named_modules():
+        if (isinstance(module, nn.Linear) and 'Plora' not in name
+                and 'lm_head' not in name and 'output_layer' not in name):
+            print_log(f'Add PLoRA to {name}', 'current')
+            add_plora_to_linear(
+                module,
+                lora_r=lora_r,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout)
diff --git a/xtuner/model/utils.py b/xtuner/model/utils.py
index 2553a369b..ebc65a9b7 100644
--- a/xtuner/model/utils.py
+++ b/xtuner/model/utils.py
@@ -3,7 +3,8 @@
 from typing import List, Optional
 
 import torch
-from mmengine import print_log
+from mmengine.dist import get_rank
+from mmengine import print_log, MessageHub
 from mmengine.utils.misc import get_object_from_string
 from peft import PeftType
 from torch import nn
@@ -129,6 +130,7 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name='default'):
 # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99  # noqa: E501
 def prepare_inputs_labels_for_multimodal(
         llm: PreTrainedModel,
+        save_im_mask=False,
         input_ids: torch.LongTensor = None,
         position_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -137,7 +139,7 @@ def prepare_inputs_labels_for_multimodal(
         pixel_values: Optional[torch.FloatTensor] = None,
         **kwargs):
     if pixel_values is None:
-        return {
+        rets = {
             'input_ids': input_ids,
             'position_ids': position_ids,
             'attention_mask': attention_mask,
@@ -145,6 +147,9 @@ def prepare_inputs_labels_for_multimodal(
             'inputs_embeds': None,
             'labels': labels
         }
+        if save_im_mask:
+            rets['im_mask'] = torch.zeros_like(input_ids, dtype=torch.bool)
+        return rets
 
     _labels = labels
     _position_ids = position_ids
@@ -170,6 +175,7 @@ def prepare_inputs_labels_for_multimodal(
     ]
 
     new_inputs_embeds = []
+    new_im_mask = []
     new_labels = []
     cur_image_idx = 0
     for batch_idx, cur_input_ids in enumerate(input_ids):
@@ -180,6 +186,11 @@ def prepare_inputs_labels_for_multimodal(
             cur_inputs_embeds = torch.cat(
                 [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
             new_inputs_embeds.append(cur_inputs_embeds)
+            new_im_mask.append(
+                torch.full((cur_inputs_embeds.shape[0], ),
+                           False,
+                           dtype=torch.bool,
+                           device=cur_inputs_embeds.device))
             new_labels.append(labels[batch_idx])
             cur_image_idx += 1
             continue
@@ -203,15 +214,26 @@ def prepare_inputs_labels_for_multimodal(
         cur_inputs_embeds_no_im = torch.split(
             cur_inputs_embeds, split_sizes, dim=0)
         cur_new_inputs_embeds = []
+        cur_new_im_mask = []
         cur_new_labels = []
 
         for i in range(num_images + 1):
             cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
+            cur_new_im_mask.append(
+                torch.full((cur_inputs_embeds_no_im[i].shape[0], ),
+                           False,
+                           dtype=torch.bool,
+                           device=cur_inputs_embeds_no_im[i].device))
             cur_new_labels.append(cur_labels_noim[i])
             if i < num_images:
                 cur_pixel_values = pixel_values[cur_image_idx]
                 cur_image_idx += 1
                 cur_new_inputs_embeds.append(cur_pixel_values)
+                cur_new_im_mask.append(
+                    torch.full((cur_pixel_values.shape[0], ),
+                               True,
+                               dtype=torch.bool,
+                               device=cur_pixel_values.device))
                 cur_new_labels.append(
                     torch.full((cur_pixel_values.shape[0], ),
                                IGNORE_INDEX,
@@ -219,9 +241,11 @@ def prepare_inputs_labels_for_multimodal(
                                dtype=cur_labels.dtype))
 
         cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
+        cur_new_im_mask = torch.cat(cur_new_im_mask)
         cur_new_labels = torch.cat(cur_new_labels)
 
         new_inputs_embeds.append(cur_new_inputs_embeds)
+        new_im_mask.append(cur_new_im_mask)
         new_labels.append(cur_new_labels)
 
     # Combine them
@@ -239,9 +263,12 @@ def prepare_inputs_labels_for_multimodal(
     position_ids = torch.zeros((batch_size, max_len),
                                dtype=position_ids.dtype,
                                device=position_ids.device)
+    im_mask = torch.zeros((batch_size, max_len),
+                          dtype=torch.bool,
+                          device=position_ids.device)
 
-    for i, (cur_new_embed,
-            cur_new_labels) in enumerate(zip(new_inputs_embeds, new_labels)):
+    for i, (cur_new_embed, cur_new_labels, cur_new_im_mask) in enumerate(zip(
+            new_inputs_embeds, new_labels, new_im_mask)):
         cur_len = cur_new_embed.shape[0]
         new_inputs_embeds_padded.append(
             torch.cat((cur_new_embed,
@@ -257,6 +284,7 @@ def prepare_inputs_labels_for_multimodal(
                 cur_len,
                 dtype=position_ids.dtype,
                 device=position_ids.device)
+            im_mask[i, :cur_len] = cur_new_im_mask
 
     new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
 
@@ -273,6 +301,11 @@ def prepare_inputs_labels_for_multimodal(
     if _position_ids is None:
         position_ids = None
 
+    if save_im_mask:
+        rank = get_rank()
+        message_hub = MessageHub.get_instance('im_mask_info')
+        message_hub.update_info(f'im_mask_{rank}', im_mask)
+
     return {
         'input_ids': None,
         'position_ids': position_ids,
@@ -314,6 +347,7 @@ def guess_load_checkpoint(pth_model):
 
 # from https://github.com/bfshi/scaling_on_scales
 import math
+
 import torch.nn.functional as F
 from einops import rearrange
 
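
Usage sketch (not part of the patch itself): a minimal example of the mechanism this patch adds, assuming the patch is applied on top of XTuner with mmengine and PyTorch installed. The toy layer sizes and the hand-built mask are illustrative only. It shows how add_plora_to_linear patches a plain nn.Linear and how the PLoRA forward consumes the per-rank image mask that prepare_inputs_labels_for_multimodal publishes through MessageHub when save_im_mask=True.

import torch
import torch.nn as nn
from mmengine import MessageHub
from mmengine.dist import get_rank

from xtuner.model.modules.plora import add_plora_to_linear

# Patch a toy linear layer with PLoRA (rank 4, no dropout for determinism).
linear = nn.Linear(16, 16)
add_plora_to_linear(linear, lora_r=4, lora_alpha=4, lora_dropout=0.0)

# Fake what prepare_inputs_labels_for_multimodal(save_im_mask=True) stores:
# a (batch, seq_len) bool mask marking image-token positions, keyed by rank.
im_mask = torch.zeros(2, 6, dtype=torch.bool)
im_mask[0, 3:] = True  # last 3 tokens of the first sequence are image tokens
MessageHub.get_instance('im_mask_info').update_info(
    f'im_mask_{get_rank()}', im_mask)

x = torch.randn(2, 6, 16)
out = linear(x)  # low-rank delta is added only where im_mask is True
print(out.shape)  # torch.Size([2, 6, 16])

Because Plora_B is zero-initialized, the patched layer initially reproduces the base Linear output; during training only the masked (image-token) positions receive the scaled low-rank update, while text-token positions go through the frozen weights untouched.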