NODE_RANK=${1:-0}
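# Usage: bash train.sh [NODE_RANK]
# NODE_RANK defaults to 0; only rank 0 cleans the remote envs and submits the ray job below.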
##############################
# training settings
##############################
# Path of training data
DATA_PATH=${DATA_PATH:-"./data/osworld_test_all.jsonl"}
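# e.g. override from the environment (hypothetical path):
#   DATA_PATH=/path/to/your_tasks.jsonl bash train.sh 0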
# Path of base model
MODEL_PATH=${MODEL_PATH:-"ByteDance-Seed/UI-TARS-7B-DPO"}
# env setting
# modify according to your OSWorld API manager server
ENV_URL=${ENV_URL:-"http://0.0.0.0"}
ENV_MANAGER_PORT=${ENV_MANAGER_PORT:-10001}
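# Several manager servers can be given as a comma-separated list; the cleanup loop below splits
# ENV_URL on ',' and the same value is passed to --env_url, e.g. (hypothetical addresses):
#   ENV_URL="http://10.0.0.1,http://10.0.0.2" bash train.sh 0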
# training hyper-param
EPISODE=${EPISODE:-1}
TRAIN_STEP=${TRAIN_STEP:-1000}
RBS=${RBS:-1} # rollout minibatch
N_SAMPLES=${N_SAMPLES:-64}
R_TARGET_SIZE=${R_TARGET_SIZE:-16384} # number of sequences to collect per rollout step
TBS=${TBS:-16384} # one update per rollout, TBS = R_TARGET_SIZE
MAX_GEN_BATCH=${MAX_GEN_BATCH:--1}
N_GROUPS=${N_GROUPS:-1}
KL_TYPE=${KL_TYPE:-"mse"}
KL=${KL:-1e-1}
LR=${LR:-2e-6}
LR_SCHEDULE=${LR_SCHEDULE:-"constant_with_warmup"}
WARMUP=${WARMUP:-0.0}
MAX_LENGTH=${MAX_LENGTH:-256}
export MIN_PIXELS=3136
export MAX_PIXELS=2116800
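# MIN_PIXELS / MAX_PIXELS bound the screenshot resolution fed to the Qwen2-VL-style image
# processor (assumption based on the UI-TARS model family); lowering MAX_PIXELS reduces memory use.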
# llm eval
# modify according to your local deployment
API_TYPE=${API_TYPE:-"qwen"}
API_MODEL=${API_MODEL:-"Qwen2.5-VL-32B-Instruct"}
API_BASE_URL=${API_BASE_URL:-"http://0.0.0.0:21101"}
API_KEY=${API_KEY:-"empty"}
EVAL_PROMPT_FILE=${EVAL_PROMPT_FILE:-"osworld_llm_eval_v1.json"}
EVAL_TEMP=${EVAL_TEMP:-1.0}
VOTING_TYPE=${VOTING_TYPE:-"all"}
VOTING_NUM=${VOTING_NUM:-4}
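# One possible setup (not prescribed by this script) is an OpenAI-compatible vLLM server hosting
# the judge model locally, with API_BASE_URL pointed at it, e.g.:
#   vllm serve Qwen/Qwen2.5-VL-32B-Instruct --port 21101
# API_KEY can stay "empty" if the server does not enforce authentication.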
# sampling setting
TEMP=${TEMP:-0.5}
TOP_P=${TOP_P:-0.9}
FREQ_PEN=${FREQ_PEN:-1}
# save & log
SAVE_MODEL_NAME=osworld_test-time
SAVE_DIR=results/train/$SAVE_MODEL_NAME
mkdir -p $SAVE_DIR
mkdir -p $SAVE_DIR/trajectory
MAX_CKPT_NUM=${MAX_CKPT_NUM:-10}
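# Rollout trajectories (--save_trajectory) are written to $SAVE_DIR/trajectory; checkpoints and
# train.log land under $SAVE_DIR as well.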
##############################
# clean all existing remote envs
IFS=',' read -ra URL_LIST <<< "$ENV_URL"
NUM_URLS=${#URL_LIST[@]}
if [[ $NODE_RANK -eq 0 ]]; then
    for (( i=0; i<NUM_URLS; i++ )); do
        url=${URL_LIST[$i]}
        curl -X POST "$url:$ENV_MANAGER_PORT/clean"
    done
fi
# ray setting
export RAY_ADDRESS="http://127.0.0.1:$DASHBOARD_PORT"
export NUMEXPR_MAX_THREADS=128
export RAY_DEDUP_LOGS=0
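# A ray cluster whose dashboard listens on $DASHBOARD_PORT is assumed to be up before this script
# runs, e.g. (hypothetical setup):
#   ray start --head --dashboard-port "$DASHBOARD_PORT"   # on the rank-0 node
#   ray start --address=<head-ip>:6379                     # on every other node
# NNODES, N_ENGINES and ENGINE_TP used below are likewise expected to come from the environment.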
if [ "$NODE_RANK" = "0" ]; then
PYTHONPATH=./:$PYTHONPATH \
ray job submit \
-- python3 -m openrlhf.cli.train_ppo_ray \
--ref_num_nodes $NNODES \
--ref_num_gpus_per_node 8 \
--actor_num_nodes $NNODES \
--actor_num_gpus_per_node 8 \
--vllm_num_engines $N_ENGINES \
--vllm_tensor_parallel_size $ENGINE_TP \
--enforce_eager \
--pretrain ${MODEL_PATH} \
--remote_rm_url empty \
--save_path $SAVE_DIR \
--ckpt_path $SAVE_DIR \
--micro_train_batch_size 1 \
--train_batch_size ${TBS} \
--micro_rollout_batch_size 1 \
--rollout_batch_size ${RBS} \
--advantage_estimator group_norm \
--use_dapo_trainer \
--dapo_dynamic_sampling \
--rollout_target_size ${R_TARGET_SIZE} \
--max_num_gen_batches ${MAX_GEN_BATCH} \
--max_samples 100000 \
--max_epochs 1 \
--num_episodes ${EPISODE} \
--num_train_steps ${TRAIN_STEP} \
--lr_warmup_ratio ${WARMUP} \
--n_samples_per_prompt $N_SAMPLES \
--prompt_max_len 20480 \
--generate_max_len $MAX_LENGTH \
--zero_stage 3 \
--bf16 \
--actor_learning_rate $LR \
--critic_learning_rate 9e-6 \
--actor_lr_schedule $LR_SCHEDULE \
--init_kl_coef $KL \
--kl_loss_coef $KL \
--kl_penalty_type $KL_TYPE \
--kl_threshold_type advantage \
--not_normalize_advantage \
--prompt_data $DATA_PATH \
--simple_load_dataset \
--packing_samples \
--flash_attn \
--gradient_checkpointing \
--save_steps 1 \
--save_hf_model \
--wandb_run_name $SAVE_MODEL_NAME \
--use_tensorboard tb_log \
--vllm_sync_backend nccl \
--max_ckpt_num $MAX_CKPT_NUM \
--group_method normal \
--use_length_reward_in_efficiency \
--temperature $TEMP \
--top_p $TOP_P \
--frequency_penalty $FREQ_PEN \
--overlap_comm \
--train_agent \
--task_group_distributed \
--num_distributed_groups $N_GROUPS \
--data_gather_redistribute \
--env_type osworld \
--env_url $ENV_URL \
--env_manager_port $ENV_MANAGER_PORT \
--test_task_llm_eval \
--action_space pyautogui \
--observation_type screenshot \
--agent_max_steps 15 \
--save_trajectory \
--agent_type uitars \
--agent_action_space computer \
--agent_prompt_language Chinese \
--num_history 5 \
--num_input_image 5 \
--use_llm_evaluator \
--api_type $API_TYPE \
--api_model $API_MODEL \
--api_base_url $API_BASE_URL \
--api_key $API_KEY \
--eval_prompt_file $EVAL_PROMPT_FILE \
--llm_eval_temperature $EVAL_TEMP \
--llm_eval_voting_type $VOTING_TYPE \
--llm_eval_voting_num $VOTING_NUM \
--load_checkpoint \
--colocate_all_models \
--vllm_enable_sleep \
--vllm_gpu_memory_utilization 0.6 \
--deepspeed_enable_sleep \
${PY_ARGS} \
2>&1 | tee $SAVE_DIR/train.log
fi