
Commit ec1ea0a

HIT-cwh
authored
[Feature] Support Microsoft Phi3 4K&128K Instruct Models (InternLM#603)
* support phi3
* dispatch sft
* rename configs
* add phi3 llava configs
* dispatch llava
* fix phi3 dispatch (#3)
* remove readme; fix ckpt name
* remove unused file
* add comma
* fix typo
* rename
* set dataloader_num_workers = 0

---------

Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: linzhihao <linzhihao@pjlab.org.cn>
1 parent 86cd930 commit ec1ea0a

12 files changed: +1957 -5 lines
@@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPImageProcessor, CLIPVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'microsoft/Phi-3-mini-4k-instruct'
visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336'
# Specify the pretrained pth
pretrained_pth = './work_dirs/llava_phi3_mini_4k_instruct_clip_vit_large_p14_336_e1_gpu8_pretrain/iter_2181.pth'  # noqa: E501

# Data
data_root = './data/llava_data/'
data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
image_folder = data_root + 'llava_images'
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = int(2048 - (336 / 14)**2)
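# Note: CLIP ViT-L/14 at 336px resolution yields (336 / 14)**2 = 24 * 24 = 576
# image patch tokens, so this reserves the image budget out of a 2048-token
# window, leaving 1472 tokens for text.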

# Scheduler & Optimizer
batch_size = 8  # per_device
accumulative_counts = 2
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03
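# Note: assuming the 8-GPU setup implied by `gpu8` in the work_dir name above,
# the effective global batch size is 8 (per device) * 2 (accumulation) * 8
# (GPUs) = 128.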

# Save
save_steps = 1000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 1000
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
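# The first prompt is Chinese for 'Please describe this picture', so both a
# Chinese and an English query are sampled during evaluation.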

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

model = dict(
    type=LLaVAModel,
    freeze_llm=False,
    freeze_visual_encoder=True,
    pretrained_pth=pretrained_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True),
    visual_encoder=dict(
        type=CLIPVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))
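# Note: these dicts are MMEngine-style lazy configs; the runner builds each
# object by calling `type` with the remaining keys as kwargs, roughly:
#   tokenizer = AutoTokenizer.from_pretrained(
#       pretrained_model_name_or_path=llm_name_or_path,
#       trust_remote_code=True, padding_side='right')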

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    pin_memory=True,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))
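# Note: LengthGroupedSampler batches samples of similar 'modality_length'
# together, which reduces padding waste when the mixed image/text-only data
# varies widely in length.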

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
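# Note: the LR warms up linearly from lr * 1e-5 over the first 3% of training
# (warmup_ratio * max_epochs, converted to iterations), then decays to 0.0 via
# cosine annealing for the remainder.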

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]
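# Note: every `evaluation_freq` (1000) iterations, EvaluateChatHook generates
# responses for the evaluation image/prompts above, so generation quality can
# be spot-checked from the training logs.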

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
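Since this is a plain MMEngine config, it can be loaded and inspected directly. A minimal sketch, assuming a file name inferred from the work_dir in pretrained_pth (the actual file name is not shown on this page):

from mmengine.config import Config

# hypothetical file name, inferred from the `gpu8_pretrain` work_dir above
cfg = Config.fromfile(
    'llava_phi3_mini_4k_instruct_clip_vit_large_p14_336_e1_gpu8_finetune.py')
print(cfg.model.llm.pretrained_model_name_or_path)
# -> 'microsoft/Phi-3-mini-4k-instruct'

In practice such configs are launched through XTuner's CLI, e.g. `xtuner train <config>`, rather than loaded by hand.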

0 commit comments