
Commit 16f7f90

Merge branch 'main' into feat-add-orpo-support
2 parents 4db2198 + 705dc32 commit 16f7f90

17 files changed (+1725, -99 lines)

examples/deepscaler/math_eval_nb.py

Lines changed: 612 additions & 0 deletions
Large diffs are not rendered by default.

examples/deepscaler/train_deepscaler.ipynb

Lines changed: 434 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@ classifiers = [
 dependencies = [
     "datasets",
     "flax>=0.11.1",
+    "fsspec", # gcsfs dependency
     "gcsfs",
     "grain",
     "huggingface_hub",
@@ -28,9 +29,11 @@ dependencies = [
     "kagglehub",
     "numba",
     "omegaconf", # CLI config
+    "pylatexenc", # Eval result parsing
     "python-dotenv", # Huggingface API key
     "qwix",
     "sentencepiece",
+    "sympy", # Eval result parsing
     "tensorboardX",
     "tensorflow_datasets",
     "tqdm",

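The new pylatexenc and sympy dependencies support eval result parsing for the math benchmarks. As a rough illustration of that kind of check (a hypothetical helper, not the repository's eval code), a LaTeX answer can be converted to plain text and compared symbolically against a reference:

# Hypothetical sketch of LaTeX answer checking with the new dependencies;
# not part of this commit.
from pylatexenc.latex2text import LatexNodes2Text
import sympy


def answers_match(latex_answer: str, reference: str) -> bool:
  """Returns True if the two answers are symbolically equivalent."""
  plain = LatexNodes2Text().latex_to_text(latex_answer)  # e.g. r"\frac{1}{2}" -> "1/2"
  try:
    return sympy.simplify(sympy.sympify(plain) - sympy.sympify(reference)) == 0
  except (sympy.SympifyError, SyntaxError, TypeError):
    # Fall back to plain string comparison when symbolic parsing fails.
    return plain.strip() == reference.strip()


print(answers_match(r"\frac{1}{2}", "1/2"))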
scripts/grpo_demo_llama3_qwen2.py

Lines changed: 33 additions & 13 deletions
@@ -85,7 +85,10 @@
     type=int,
     default=1869,
     required=False,
-    help="Number of batches for training.",
+    help=(
+        "Number of batches for training. Defaults to total number of samples //"
+        " global batch size."
+    ),
 )
 parser.add_argument(
     "--num-test-batches",
@@ -94,6 +97,27 @@
     required=False,
     help="Number of test batches for evaluation.",
 )
+parser.add_argument(
+    "--global-batch-size",
+    type=int,
+    default=4,
+    required=False,
+    help="Number of global batches for learning.",
+)
+parser.add_argument(
+    "--train-micro-batch-size",
+    type=int,
+    default=2,
+    required=False,
+    help="Number of micro batches for training.",
+)
+parser.add_argument(
+    "--train-mini-batch-size",
+    type=int,
+    default=4,
+    required=False,
+    help="Number of mini batches for training.",
+)
 parser.add_argument(
     "--rollout-engine",
     type=str,
@@ -163,7 +187,7 @@
 # ====== GRPO ======
 # === Generation during GRPO training ===
 MAX_PROMPT_LENGTH = 256
-TOTAL_GENERATION_STEPS = 1024  # YY 768
+TOTAL_GENERATION_STEPS = 768
 # Important to keep a high-ish temperature for varied, diverse responses during
 # training.
 TEMPERATURE = 0.9
@@ -186,17 +210,14 @@
 EPSILON = 0.2

 # ====== Training ======
-# 2 is the max we can do on v5e-8 with llama3 8B model.
-# 4 is the max we can do on v5e-8 with llama3 1B model.
-TRAIN_MICRO_BATCH_SIZE = 4
 # To speed up for quick workflow validation, we can change NUM_BATCHES to e.g. 2
-NUM_BATCHES = args.num_batches
+NUM_BATCHES = min(args.num_batches, 7473 // args.global_batch_size)
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 # To speed up for quick workflow validation, we can change it to e.g. 1
 NUM_TEST_BATCHES = args.num_test_batches

-EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
+EVAL_EVERY_N_STEPS = 1000  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1  # can potentially train for more epochs

 # Number of training steps.
@@ -344,7 +365,7 @@ def get_dataset(path: str) -> grain.MapDataset:
   return loaded_dataset


-dataset = get_dataset(TRAIN_DATA_PATH).batch(TRAIN_MICRO_BATCH_SIZE)[
+dataset = get_dataset(TRAIN_DATA_PATH).batch(args.global_batch_size)[
     :NUM_BATCHES
 ]

@@ -357,7 +378,7 @@ def get_dataset(path: str) -> grain.MapDataset:

 val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

-test_dataset = get_dataset(TEST_DATA_PATH).batch(TRAIN_MICRO_BATCH_SIZE)[
+test_dataset = get_dataset(TEST_DATA_PATH).batch(args.global_batch_size)[
     :NUM_TEST_BATCHES
 ]

@@ -627,7 +648,7 @@ def generate(

   out_data = sampler(
       input_strings=input_batch,
-      max_generation_steps=768,
+      max_generation_steps=TOTAL_GENERATION_STEPS,
       temperature=temperature,
       top_k=top_k,
       top_p=top_p,
@@ -782,8 +803,8 @@ def evaluate(
     actor_optimizer=optimizer,
     eval_every_n_steps=EVAL_EVERY_N_STEPS,
     max_steps=MAX_STEPS,
-    mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
-    train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
+    mini_batch_size=args.train_mini_batch_size,
+    train_micro_batch_size=args.train_micro_batch_size,
     # metrics logging
     metrics_logging_options=metrics_logging_options,
     # checkpoint saving
@@ -802,7 +823,6 @@ def evaluate(
         rollout_vllm_tpu_backend_type="jax",
         rollout_vllm_server_mode=args.rollout_server_mode,
     ),
-
 )

 grpo_config = grpo_learner.GRPOConfig(

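The three new batch-size flags replace the hard-coded TRAIN_MICRO_BATCH_SIZE constant. The commit itself does not spell out their exact semantics, but under the usual convention a global batch is split into mini batches (one optimizer update each), and each mini batch is further split into micro batches whose gradients are accumulated. A minimal sketch of that arithmetic with the new defaults (illustrative only; the real scheduling lives in tunix's GRPO learner):

# Illustrative arithmetic only, using the defaults added above.
global_batch_size = 4        # --global-batch-size: prompts sampled per training batch
train_mini_batch_size = 4    # --train-mini-batch-size: examples per optimizer update
train_micro_batch_size = 2   # --train-micro-batch-size: examples per forward/backward pass

assert global_batch_size % train_mini_batch_size == 0
assert train_mini_batch_size % train_micro_batch_size == 0

optimizer_steps_per_batch = global_batch_size // train_mini_batch_size     # 1
grad_accumulation_steps = train_mini_batch_size // train_micro_batch_size  # 2

# NUM_BATCHES is now capped by the dataset size divided by the global batch
# size (7473 matches the GSM8K train split used by the script).
num_batches = min(1869, 7473 // global_batch_size)  # 1868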
tests/generate/sglang_jax_sampler_test.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def test_sglang_jax_sampler(self):
     self.assertTrue(
         np.allclose(
             tunix_state["embedder"]["input_embedding"].value,
-            sglangjax_state["transformer"]["embed_tokens"]["embedding"].value,
+            sglangjax_state["model"]["embed_tokens"]["embedding"].value,
         )
     )

tunix/generate/sglang_jax_sampler.py

Lines changed: 0 additions & 1 deletion
@@ -108,7 +108,6 @@ def _sglang_jax_config(self, config: SglangJaxConfig):
     args["model_path"] = config.model_version
     args["precompile_bs_paddings"] = [1, 64]
     args["precompile_token_paddings"] = [8192]
-    args["disable_jax_precompile"] = True
     args["page_size"] = 64
     args["context_length"] = config.context_length
     args["tp_size"] = self._find_tp_size(config.mesh)

tunix/models/llama3/mapping_sglang_jax.py

Lines changed: 11 additions & 11 deletions
@@ -13,46 +13,46 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
   return {
       'lm_head.w': ('lm_head.embedding', (None, 'model')),
       'embedder.input_embedding': (
-          'transformer.embed_tokens.embedding',
+          'model.embed_tokens.embedding',
           ('model', None),
       ),
       'layers.*.input_layernorm.w': (
-          'transformer.layers.*.input_layernorm.scale',
+          'model.layers.*.input_layernorm.scale',
           (None,),
       ),
       'layers.*.mlp.down_proj.kernel': (
-          'transformer.layers.*.mlp.down_proj.weight',
+          'model.layers.*.mlp.down_proj.weight',
           ('model', None),
       ),
       'layers.*.mlp.gate_proj.kernel': (
-          'transformer.layers.*.mlp.gate_proj.weight',
+          'model.layers.*.mlp.gate_proj.weight',
           (None, 'model'),
       ),
       'layers.*.mlp.up_proj.kernel': (
-          'transformer.layers.*.mlp.up_proj.weight',
+          'model.layers.*.mlp.up_proj.weight',
           (None, 'model'),
       ),
       'layers.*.post_attention_layernorm.w': (
-          'transformer.layers.*.post_attention_layernorm.scale',
+          'model.layers.*.post_attention_layernorm.scale',
           (None,),
       ),
       'layers.*.attn.k_proj.w': (
-          'transformer.layers.*.self_attn.k_proj.weight',
+          'model.layers.*.self_attn.k_proj.weight',
           (None, 'model', None),
       ),
       'layers.*.attn.o_proj.w': (
-          'transformer.layers.*.self_attn.o_proj.weight',
+          'model.layers.*.self_attn.o_proj.weight',
           ('model', None, None),
       ),
       'layers.*.attn.q_proj.w': (
-          'transformer.layers.*.self_attn.q_proj.weight',
+          'model.layers.*.self_attn.q_proj.weight',
           (None, 'model', None),
       ),
       'layers.*.attn.v_proj.w': (
-          'transformer.layers.*.self_attn.v_proj.weight',
+          'model.layers.*.self_attn.v_proj.weight',
           (None, 'model', None),
       ),
-      'final_norm.w': ('transformer.norm.scale', (None,)),
+      'final_norm.w': ('model.norm.scale', (None,)),
   }

tunix/models/qwen2/mapping_sglang_jax.py

Lines changed: 14 additions & 14 deletions
@@ -13,58 +13,58 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
   return {
       'lm_head.w': ('lm_head.embedding', (None, 'model')),
       'embedder.input_embedding': (
-          'transformer.embed_tokens.embedding',
+          'model.embed_tokens.embedding',
           ('model', None),
       ),
       'layers.*.input_layernorm.w': (
-          'transformer.layers.*.input_layernorm.scale',
+          'model.layers.*.input_layernorm.scale',
           (None,),
       ),
       'layers.*.mlp.down_proj.kernel': (
-          'transformer.layers.*.mlp.down_proj.weight',
+          'model.layers.*.mlp.down_proj.weight',
           ('model', None),
       ),
       'layers.*.mlp.gate_proj.kernel': (
-          'transformer.layers.*.mlp.gate_proj.weight',
+          'model.layers.*.mlp.gate_proj.weight',
           (None, 'model'),
       ),
       'layers.*.mlp.up_proj.kernel': (
-          'transformer.layers.*.mlp.up_proj.weight',
+          'model.layers.*.mlp.up_proj.weight',
           (None, 'model'),
       ),
       'layers.*.post_attention_layernorm.w': (
-          'transformer.layers.*.post_attention_layernorm.scale',
+          'model.layers.*.post_attention_layernorm.scale',
           (None,),
       ),
       'layers.*.attn.k_proj.w': (
-          'transformer.layers.*.self_attn.k_proj.weight',
+          'model.layers.*.self_attn.k_proj.weight',
           (None, 'model', None),
       ),
       'layers.*.attn.k_bias': (
-          'transformer.layers.*.self_attn.k_proj.bias',
+          'model.layers.*.self_attn.k_proj.bias',
           (None,),
       ),
       'layers.*.attn.o_proj.w': (
-          'transformer.layers.*.self_attn.o_proj.weight',
+          'model.layers.*.self_attn.o_proj.weight',
           ('model', None, None),
       ),
       'layers.*.attn.q_proj.w': (
-          'transformer.layers.*.self_attn.q_proj.weight',
+          'model.layers.*.self_attn.q_proj.weight',
           (None, 'model', None),
       ),
       'layers.*.attn.q_bias': (
-          'transformer.layers.*.self_attn.q_proj.bias',
+          'model.layers.*.self_attn.q_proj.bias',
           (None,),
       ),
       'layers.*.attn.v_proj.w': (
-          'transformer.layers.*.self_attn.v_proj.weight',
+          'model.layers.*.self_attn.v_proj.weight',
           (None, 'model', None),
       ),
       'layers.*.attn.v_bias': (
-          'transformer.layers.*.self_attn.v_proj.bias',
+          'model.layers.*.self_attn.v_proj.bias',
           (None,),
       ),
-      'final_norm.w': ('transformer.norm.scale', (None,)),
+      'final_norm.w': ('model.norm.scale', (None,)),
   }

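Both mapping tables now target the model.* parameter-tree prefix instead of transformer.*, matching the key layout exposed by sglang-jax for these models. The sketch below shows how a wildcard table of this shape can be applied to rename keys; it is a hedged illustration, not tunix's actual conversion code, and it ignores the sharding tuples entirely:

# Hedged illustration of wildcard key renaming; not tunix's loader code.
import re

MAPPINGS = {
    'embedder.input_embedding': 'model.embed_tokens.embedding',
    'layers.*.mlp.up_proj.kernel': 'model.layers.*.mlp.up_proj.weight',
    'final_norm.w': 'model.norm.scale',
}


def translate(src_key: str) -> str:
  """Maps a source key to its target name, carrying over the layer index."""
  for pattern, target in MAPPINGS.items():
    regex = '^' + re.escape(pattern).replace(r'\*', r'(\d+)') + '$'
    match = re.match(regex, src_key)
    if match:
      for index in match.groups():
        target = target.replace('*', index, 1)
      return target
  raise KeyError(f'No mapping for {src_key}')


print(translate('layers.3.mlp.up_proj.kernel'))  # model.layers.3.mlp.up_proj.weight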
tunix/models/safetensors_loader.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,11 @@
 import jax.numpy as jnp
 import safetensors.flax as safetensors

+# DO NOT CHANGE THIS IMPORT. This is used in both oss and GOOGLE_INTERNAL_PACKAGE_PATH.
+from tunix.oss import utils
+
+load_file_from_gcs = utils.load_file_from_gcs
+

 def torch_key_to_jax_key(mapping, source_key):
   """Convert torch key to jax key using the provided mapping."""
@@ -78,6 +83,9 @@ def load_and_create_model(
   Returns:
     Model instance with loaded weights
   """
+
+  file_dir = load_file_from_gcs(file_dir)
+
   files = list(epath.Path(file_dir).expanduser().glob("*.safetensors"))

   if not files:

tunix/oss/utils.py

Lines changed: 29 additions & 0 deletions
@@ -15,6 +15,8 @@

 import os

+import fsspec
+

 def pathways_available() -> bool:
   if "proxy" not in os.getenv("JAX_PLATFORMS", ""):
@@ -25,3 +27,30 @@ def pathways_available() -> bool:
     return True
   except ImportError:
     return False
+
+
+def load_file_from_gcs(file_dir: str, target_dir: str = None) -> str:
+  """Load file from GCS."""
+  if file_dir.startswith("/"):
+    return file_dir
+
+  if not file_dir.startswith("gs://"):
+    raise ValueError(f"Invalid GCS path: {file_dir}")
+
+  _, prefix = file_dir[5:].split("/", 1)
+  try:
+    import tempfile  # pylint: disable=g-import-not-at-top
+
+    if target_dir is None:
+      target_dir = tempfile.gettempdir()
+    local_dir = os.path.join(target_dir, prefix)
+
+    fsspec_fs = fsspec.filesystem("gs")
+    fsspec_fs.get(file_dir, local_dir, recursive=True)
+
+    return local_dir
+  except ImportError as e:
+    raise ImportError(
+        "Please install google-cloud-storage to load model from GCS."
+    ) from e

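A short usage sketch for the new helper (the bucket path below is hypothetical). Because safetensors_loader.load_and_create_model now routes file_dir through this function first, checkpoint directories can be passed either as local paths or as gs:// URIs:

from tunix.oss import utils

# A gs:// path is mirrored into a local temp directory via fsspec/gcsfs ...
local_dir = utils.load_file_from_gcs("gs://my-bucket/checkpoints/llama3-1b")

# ... while an absolute local path is returned unchanged.
same_dir = utils.load_file_from_gcs("/tmp/checkpoints/llama3-1b")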