Commit faf4d72

Merge branch 'google:main' into feat-add-orpo-support

2 parents 16f7f90 + dbe1227 · commit faf4d72

27 files changed: +2362 −607 lines
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
```yaml
# Copyright 2025 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow builds the tunix python package and runs TPU regression tests.

name: Tunix Nightly Regression Tests

on:
  workflow_dispatch:
  schedule:
    # Run the job every day at 2am
    - cron: '0 2 * * *'

concurrency:
  # Dedup scheduled runs but nothing else
  group: >
    ${{
      github.event_name == 'schedule' && format('{0}-schedule', github.workflow) ||
      github.run_id
    }}
  cancel-in-progress: false

permissions:
  contents: read

jobs:
  build_tunix_package:
    name: Build tunix package
    uses: ./.github/workflows/build_package.yml

  tunix_tpu_nightly_regression:
    needs: build_tunix_package
    uses: ./.github/workflows/tpu-nightly-regression.yml
    secrets:
      HF_TOKEN: ${{ secrets.HF_TOKEN }}

  notify_failure:
    name: Notify failed build  # creates an issue or updates the last open issue for a failed build
    needs: [build_tunix_package, tunix_tpu_nightly_regression]
    if: ${{ always() }}
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Check whether one of the jobs failed
        if: ${{ contains(needs.*.result, 'failure') && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
        uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
```
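The `notify_failure` gate above combines three checks. A toy Python re-implementation of the same predicate (an illustration only, not code from this repo): file an issue only when some needed job failed, the run is not a pull request, and it was not started manually.

```python
# Toy sketch of the `if:` expression on the notify step (illustration only).
def should_file_issue(
    job_results: list[str], event_name: str, is_pull_request: bool
) -> bool:
  return (
      "failure" in job_results  # contains(needs.*.result, 'failure')
      and not is_pull_request  # github.event.pull_request == null
      and event_name != "workflow_dispatch"  # not a manual run
  )


assert should_file_issue(["success", "failure"], "schedule", False)
assert not should_file_issue(["failure"], "workflow_dispatch", False)
assert not should_file_issue(["success", "success"], "schedule", False)
```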
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
```yaml
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This reusable workflow installs tunix dependencies and runs the nightly TPU regression scripts.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tunix Nightly Regression

on:
  workflow_call:
    secrets:
      HF_TOKEN:
        required: true
        description: 'HuggingFace token for model downloads'

concurrency:
  # Dedup pull requests (canceling previous runs of the same workflow for the same PR) and scheduled runs, but nothing else
  group: ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || github.run_id }}
  cancel-in-progress: true

env:
  HF_HOME: ~/.cache/huggingface
  HF_HUB_ENABLE_HF_TRANSFER: "1"

jobs:
  run_prod:
    runs-on: [linux-x86-ct5lp-224-8tpu]
    environment: testing
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:jax0.7.1_rev1
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu
    steps:

      # Cache Hugging Face hub
      - name: Cache HF hub
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface
          key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
          restore-keys: |
            hf-${{ runner.os }}-

      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Install tunix dependencies
        run: |
          pip install -e .[prod]
          pip install pytest pytest-xdist

      - name: Verify TPU availability
        run: |
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')

          # Check if we have TPU devices specifically
          devices = jax.devices()
          has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
          print(f'TPU available: {has_tpu}')

          if not has_tpu:
            print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
            exit(1)
          else:
            print(f'SUCCESS: Found {len(devices)} TPU device(s)')
          "

      - name: Run regression scripts
        id: regression_tests
        run: |
          FAILED=0
          echo "Running tunix/oss/examples/deepscaler/math_eval_nb.py..."
          python tunix/oss/examples/deepscaler/math_eval_nb.py || FAILED=1

          echo "Running tunix/oss/scripts/grpo_demo_llama3_qwen2.py..."
          python tunix/oss/scripts/grpo_demo_llama3_qwen2.py || FAILED=1

          if [ "$FAILED" -ne 0 ]; then
            echo "One or more scripts failed!"
            exit 1
          fi
```
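The final step uses a run-everything-then-fail pattern: both scripts run even if the first one fails, and the step exits nonzero at the end if anything failed. A rough Python equivalent of that shell logic (a sketch, not code from this repo):

```python
# Sketch of the "Run regression scripts" step: run every script, remember
# any failure, and fail once at the end so later scripts still execute.
import subprocess
import sys

SCRIPTS = [
    "tunix/oss/examples/deepscaler/math_eval_nb.py",
    "tunix/oss/scripts/grpo_demo_llama3_qwen2.py",
]

failed = False
for script in SCRIPTS:
  print(f"Running {script}...")
  if subprocess.run([sys.executable, script]).returncode != 0:
    failed = True

if failed:
  print("One or more scripts failed!")
  sys.exit(1)
```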

.github/workflows/tpu-tests.yml

Lines changed: 2 additions & 1 deletion
```diff
@@ -61,7 +61,8 @@ jobs:
 
       - name: Install tunix dependencies
         run: |
-          pip install -e .[prod]
+          pip install --upgrade pip
+          pip install -e .[prod] --force-reinstall
           pip install pytest pytest-xdist
 
       - name: Verify TPU availability
```

examples/deepscaler/math_eval_nb.py

Lines changed: 26 additions & 23 deletions
```diff
@@ -1,14 +1,16 @@
 # %%
 from pprint import pprint
-from datasets import Dataset
+import datasets as datasets_lib
 import grain
 import pandas as pd
 import os
 import fsspec
 
-from transformers import AutoTokenizer
+import transformers
 from tunix.generate import mappings
 
+Dataset = datasets_lib.Dataset
+AutoTokenizer = transformers.AutoTokenizer
 
 try:
   from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
@@ -38,15 +40,15 @@
 from tunix.generate import sampler as sampler_lib
 from tunix.utils import math_utils
 # %%
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import jax
 from tqdm.auto import tqdm
 import re
 
 # Only used for Math500
 def extract_answer_robust(passage: str) -> str:
   if not passage:
-    return None
+    return ""
 
   # Pattern 1: Look for \boxed{...} with proper matching braces
   # This handles nested braces like \boxed{\frac{1}{2}}
@@ -107,7 +109,7 @@ def extract_answer_robust(passage: str) -> str:
         break
     return answer.strip().rstrip(".,;:)")
 
-  return None
+  return ""
 # %%
 
 # only used for AIME-2024
```
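Context for the `\boxed{...}` comments above: a plain regex cannot capture nested braces such as `\boxed{\frac{1}{2}}`, so the extraction needs a brace counter. A minimal sketch of that idea (`extract_boxed` is a hypothetical helper, not the file's actual implementation); it returns `""` on failure, matching the return convention this diff adopts:

```python
def extract_boxed(passage: str) -> str:
  """Return the contents of the last \\boxed{...}, or "" if none is found."""
  marker = r"\boxed{"
  start = passage.rfind(marker)
  if start == -1:
    return ""
  i = start + len(marker)
  begin, depth = i, 1
  # Walk forward, tracking brace depth so nested arguments stay intact.
  while i < len(passage) and depth > 0:
    if passage[i] == "{":
      depth += 1
    elif passage[i] == "}":
      depth -= 1
    i += 1
  if depth != 0:
    return ""  # unbalanced braces: bail out rather than guess
  return passage[begin : i - 1].strip()


assert extract_boxed(r"so the answer is \boxed{\frac{1}{2}}") == r"\frac{1}{2}"
assert extract_boxed("no boxed answer here") == ""
```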
```diff
@@ -160,10 +162,6 @@ def evaluate_correctness(response: Any, ground_truths: Any) -> bool:
   return False
 # %%
 
-from transformers import AutoTokenizer
-from pprint import pprint
-import grain
-
 class Qwen25MathEvaluator:
 
   def __init__(
@@ -228,20 +226,20 @@ def load_model(self):
     )
 
     if self.sampler_type == "vanilla":
-      self.sampler = sampler_lib.Sampler(
+      self.sampler_vanilla = sampler_lib.Sampler(
           transformer=self.model,
           tokenizer=self.tokenizer,
           cache_config=cache_config,
       )
     elif self.sampler_type == "sglang-jax":
-      from tunix.generate import sglang_jax_sampler  # pylint: disable=g-import-not-at-top
+      from tunix.google.stubs import sglang_jax_sampler_stub as sglang_jax_sampler  # pylint: disable=g-import-not-at-top
 
       mapping_config = mappings.MappingConfig.build(
           mapping_obj=None,
           model=self.model,
           backend="sglang_jax",
       )
-      self.sampler = sglang_jax_sampler.SglangJaxSampler(
+      self.sampler_sglang = sglang_jax_sampler.SglangJaxSampler(
           tokenizer=self.tokenizer,
           config=sglang_jax_sampler.SglangJaxConfig(
               mesh=self.mesh,
@@ -328,8 +326,12 @@ def generate(
       temperature: float = 0.6,
       top_k: int = 50,
       top_p: float = 0.95,
-      seed: int = None,
+      seed: int | None = None,
   ) -> str:
+    if self.tokenizer is None:
+      raise RuntimeError(
+          "Model components not loaded. Call load_model() first."
+      )
     max_length = max(len(self.tokenizer.encode(p)) for p in prompts)
     cache_size = self.max_prompt_length + self.max_generation_steps + 100
     safe_gen_length = min(
@@ -346,7 +348,7 @@
 
     # Generate
     if self.sampler_type == "vanilla":
-      out_data = self.sampler(
+      out_data = self.sampler_vanilla(
           input_strings=prompts,
           max_generation_steps=safe_gen_length,
           temperature=temperature,
@@ -357,7 +359,7 @@
           seed=jax.random.PRNGKey(seed) if seed is not None else None,
       )
     elif self.sampler_type == "sglang-jax":
-      out_data = self.sampler(
+      out_data = self.sampler_sglang(
           input_strings=prompts,
           max_generation_steps=safe_gen_length,
           max_prompt_length=self.max_prompt_length,
```
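The `seed: int | None = None` change pairs with the existing `seed=jax.random.PRNGKey(seed) if seed is not None else None` call: callers pass a plain int, and a JAX key is derived only when a seed is given. A minimal standalone sketch of that pattern (illustration, not repo code):

```python
# Sketch: convert an optional int seed into an optional JAX PRNG key.
import jax


def make_key(seed: int | None):
  return jax.random.PRNGKey(seed) if seed is not None else None


assert make_key(None) is None
key = make_key(0)
print(jax.random.uniform(key, (2,)))  # same output every run for seed=0
```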
```diff
@@ -370,22 +372,22 @@
       )
     else:
       raise ValueError(f"Unsupported sampler type: {self.sampler_type}")
-    return out_data.text
+    return out_data.text[0]
 
   def evaluate(
       self,
       batch_size: int = 8,
-      num_batches: int = None,
+      num_batches: int | None = None,
       temperature: float = 0.6,
-      top_k: int = 50,
-      top_p: float = 0.95,
+      top_k: Optional[int] = 50,
+      top_p: Optional[float] = 0.95,
       num_passes: int = 1,
       debug_first_n: int = 3,  # NEW: Debug first N examples
   ) -> Dict[str, Any]:
     print("=" * 60)
     print("Starting Evaluation")
     print("=" * 60)
-    print(f"Configuration:")
+    print("Configuration:")
     print(f"  Batch size: {batch_size}")
     print(f"  Num batches: {num_batches or 'all'}")
     print(f"  Temperature: {temperature}")
@@ -467,7 +469,8 @@ def evaluate(
       print(f"Ground truth: {answer}")
       print("=" * 60 + "\n")
       print(f"Prompt (first 300 chars): {prompt[:]}")
-      print(f"Prompt length: {len(self.tokenizer.encode(prompt))} tokens")
+      if self.tokenizer is not None and hasattr(self.tokenizer, "encode"):
+        print(f"Prompt length: {len(self.tokenizer.encode(prompt))} tokens")
       print("=" * 60 + "\n")
       for i, (response, ans, cor) in enumerate(
           zip(responses, extracted_answers, answer_correct)
@@ -553,7 +556,7 @@ def evaluate(
 print("\nStarting evaluation...")
 results = evaluator.evaluate(
     batch_size=8,
-    # num_batches=3,
+    num_batches=None,
     temperature=0.6,
     top_k=50,
     top_p=0.95,
@@ -592,7 +595,7 @@ def evaluate(
 
 results = evaluator.evaluate(
     batch_size=1,
-    # num_batches=3,
+    num_batches=None,
     temperature=0.6,
     top_k=None,
     top_p=0.95,
```
