11# %%
22from pprint import pprint
3- from datasets import Dataset
3+ import datasets as datasets_lib
44import grain
55import pandas as pd
66import os
77import fsspec
88
9- from transformers import AutoTokenizer
9+ import transformers
1010from tunix .generate import mappings
1111
12+ Dataset = datasets_lib .Dataset
13+ AutoTokenizer = transformers .AutoTokenizer
1214
1315try :
1416 from GOOGLE_INTERNAL_PACKAGE_PATH .pyglib import gfile
3840 from tunix .generate import sampler as sampler_lib
3941 from tunix .utils import math_utils
4042# %%
41- from typing import Any , Dict
43+ from typing import Any , Dict , Optional
4244import jax
4345from tqdm .auto import tqdm
4446import re
4547
4648# Only used for Math500
4749def extract_answer_robust (passage : str ) -> str :
4850 if not passage :
49- return None
51+ return ""
5052
5153 # Pattern 1: Look for \boxed{...} with proper matching braces
5254 # This handles nested braces like \boxed{\frac{1}{2}}
@@ -107,7 +109,7 @@ def extract_answer_robust(passage: str) -> str:
107109 break
108110 return answer .strip ().rstrip (".,;:)" )
109111
110- return None
112+ return ""
111113# %%
112114
113115# only used for AIME-2024
@@ -160,10 +162,6 @@ def evaluate_correctness(response: Any, ground_truths: Any) -> bool:
160162 return False
161163# %%
162164
163- from transformers import AutoTokenizer
164- from pprint import pprint
165- import grain
166-
167165class Qwen25MathEvaluator :
168166
169167 def __init__ (
@@ -228,20 +226,20 @@ def load_model(self):
228226 )
229227
230228 if self .sampler_type == "vanilla" :
231- self .sampler = sampler_lib .Sampler (
229+ self .sampler_vanilla = sampler_lib .Sampler (
232230 transformer = self .model ,
233231 tokenizer = self .tokenizer ,
234232 cache_config = cache_config ,
235233 )
236234 elif self .sampler_type == "sglang-jax" :
237- from tunix .generate import sglang_jax_sampler # pylint: disable=g-import-not-at-top
235+ from tunix .google . stubs import sglang_jax_sampler_stub as sglang_jax_sampler # pylint: disable=g-import-not-at-top
238236
239237 mapping_config = mappings .MappingConfig .build (
240238 mapping_obj = None ,
241239 model = self .model ,
242240 backend = "sglang_jax" ,
243241 )
244- self .sampler = sglang_jax_sampler .SglangJaxSampler (
242+ self .sampler_sglang = sglang_jax_sampler .SglangJaxSampler (
245243 tokenizer = self .tokenizer ,
246244 config = sglang_jax_sampler .SglangJaxConfig (
247245 mesh = self .mesh ,
@@ -328,8 +326,12 @@ def generate(
328326 temperature : float = 0.6 ,
329327 top_k : int = 50 ,
330328 top_p : float = 0.95 ,
331- seed : int = None ,
329+ seed : int | None = None ,
332330 ) -> str :
331+ if self .tokenizer is None :
332+ raise RuntimeError (
333+ "Model components not loaded. Call load_model() first."
334+ )
333335 max_length = max (len (self .tokenizer .encode (p )) for p in prompts )
334336 cache_size = self .max_prompt_length + self .max_generation_steps + 100
335337 safe_gen_length = min (
@@ -346,7 +348,7 @@ def generate(
346348
347349 # Generate
348350 if self .sampler_type == "vanilla" :
349- out_data = self .sampler (
351+ out_data = self .sampler_vanilla (
350352 input_strings = prompts ,
351353 max_generation_steps = safe_gen_length ,
352354 temperature = temperature ,
@@ -357,7 +359,7 @@ def generate(
357359 seed = jax .random .PRNGKey (seed ) if seed is not None else None ,
358360 )
359361 elif self .sampler_type == "sglang-jax" :
360- out_data = self .sampler (
362+ out_data = self .sampler_sglang (
361363 input_strings = prompts ,
362364 max_generation_steps = safe_gen_length ,
363365 max_prompt_length = self .max_prompt_length ,
@@ -370,22 +372,22 @@ def generate(
370372 )
371373 else :
372374 raise ValueError (f"Unsupported sampler type: { self .sampler_type } " )
373- return out_data .text
375+ return out_data .text [ 0 ]
374376
375377 def evaluate (
376378 self ,
377379 batch_size : int = 8 ,
378- num_batches : int = None ,
380+ num_batches : int | None = None ,
379381 temperature : float = 0.6 ,
380- top_k : int = 50 ,
381- top_p : float = 0.95 ,
382+ top_k : Optional [ int ] = 50 ,
383+ top_p : Optional [ float ] = 0.95 ,
382384 num_passes : int = 1 ,
383385 debug_first_n : int = 3 , # NEW: Debug first N examples
384386 ) -> Dict [str , Any ]:
385387 print ("=" * 60 )
386388 print ("Starting Evaluation" )
387389 print ("=" * 60 )
388- print (f "Configuration:" )
390+ print ("Configuration:" )
389391 print (f" Batch size: { batch_size } " )
390392 print (f" Num batches: { num_batches or 'all' } " )
391393 print (f" Temperature: { temperature } " )
@@ -467,7 +469,8 @@ def evaluate(
467469 print (f"Ground truth: { answer } " )
468470 print ("=" * 60 + "\n " )
469471 print (f"Prompt (first 300 chars): { prompt [:]} " )
470- print (f"Prompt length: { len (self .tokenizer .encode (prompt ))} tokens" )
472+ if self .tokenizer is not None and hasattr (self .tokenizer , "encode" ):
473+ print (f"Prompt length: { len (self .tokenizer .encode (prompt ))} tokens" )
471474 print ("=" * 60 + "\n " )
472475 for i , (response , ans , cor ) in enumerate (
473476 zip (responses , extracted_answers , answer_correct )
@@ -553,7 +556,7 @@ def evaluate(
553556print ("\n Starting evaluation..." )
554557results = evaluator .evaluate (
555558 batch_size = 8 ,
556- # num_batches=3 ,
559+ num_batches = None ,
557560 temperature = 0.6 ,
558561 top_k = 50 ,
559562 top_p = 0.95 ,
@@ -592,7 +595,7 @@ def evaluate(
592595
593596results = evaluator .evaluate (
594597 batch_size = 1 ,
595- # num_batches=3 ,
598+ num_batches = None ,
596599 temperature = 0.6 ,
597600 top_k = None ,
598601 top_p = 0.95 ,
0 commit comments