 if TYPE_CHECKING:
     from transformers import Tokenizer

+logger = logging.getLogger(__name__)
+
+
+def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]:
+    """Batch-tokenize texts across ~95% of available cores and return token counts.
+
+    Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores.
+    HuggingFace tokenizers use a Rust backend that releases the GIL,
+    so threads achieve real parallelism without GIL contention.
+    """
+    from concurrent.futures import ThreadPoolExecutor
+
+    n_cores = os.cpu_count() or 1
+    n_workers = max(1, int(n_cores * 0.95))
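+    # e.g. 16 cores -> 15 workers; the max() keeps at least one worker on small machines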
+
+    if len(texts) <= n_workers:
+        # Few texts: just tokenize directly, no threading overhead
+        encoded = tokenizer(texts, add_special_tokens=False)
+        return [len(ids) for ids in encoded["input_ids"]]
+
+    # Split texts into chunks, one per worker
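+    # (ceiling division: every text lands in exactly one chunk, at most n_workers chunks)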
+    chunk_size = (len(texts) + n_workers - 1) // n_workers
+    chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]
+
+    def _tokenize_chunk(chunk: list[str]) -> list[int]:
+        encoded = tokenizer(chunk, add_special_tokens=False)
+        return [len(ids) for ids in encoded["input_ids"]]
+
+    results: list[int] = []
+    with ThreadPoolExecutor(max_workers=n_workers) as pool:
+        for chunk_lengths in pool.map(_tokenize_chunk, chunks):
+            results.extend(chunk_lengths)
+    return results
+

 class TPOTReportingMode(str, Enum):
     """TPOT (Time Per Output Token) reporting mode.
@@ -1034,7 +1068,9 @@ def get_output_sequence_lengths(
         """
         query_result = self.get_sample_outputs()

-        rows = []
+        # Collect all texts for batch tokenization
+        uuids: list[str] = []
+        texts: list[str] = []
         for sample_uuid, data_bytes in query_result:
             output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes)

@@ -1047,13 +1083,16 @@ def get_output_sequence_lengths(
             else:
                 full_sequence = output_sequence

-            # Tokenize and calculate length
-            output_tokens = tokenizer.tokenize(full_sequence)
-            rows.append((sample_uuid, len(output_tokens)))
+            uuids.append(sample_uuid)
+            texts.append(full_sequence)

-        if not rows:
+        if not texts:
             return None

+        # Parallel batch tokenize across ~95% of cores
+        token_counts = _parallel_batch_tokenize(tokenizer, texts)
+        rows = list(zip(uuids, token_counts, strict=False))
+
         return RollupQueryTable("output_sequence_length", None, rows)

     @profile
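For scale, the old per-sample `tokenizer.tokenize(...)` path and the new batched path can be timed side by side. A rough sketch, reusing `tokenizer` and `texts` from the example above (actual speedups depend on core count and text lengths):

    import time

    def count_per_sample(tokenizer, texts):
        # Old shape: one Python-level tokenizer call per text, no parallelism.
        return [len(tokenizer.tokenize(t)) for t in texts]

    t0 = time.perf_counter()
    count_per_sample(tokenizer, texts)
    print(f"per-sample: {time.perf_counter() - t0:.2f}s")

    t0 = time.perf_counter()
    _parallel_batch_tokenize(tokenizer, texts)
    print(f"batched:    {time.perf_counter() - t0:.2f}s")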
@@ -1103,11 +1142,9 @@ def derive_TPOT(
         if not query_result:
             return None

-        rows = []
-        if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
-            repeats = []
-        else:
-            repeats = None
+        # Pass 1: Collect all non-first-chunk texts for batch tokenization
+        batch_uuids: list[str] = []
+        batch_texts: list[str] = []

         for sample_uuid, data_bytes in query_result:
             if data_bytes is None or len(data_bytes) == 0:
@@ -1155,9 +1192,25 @@ def derive_TPOT(
                 # Possible malformed output data where empty string is included as a non-first chunk
                 continue

-            non_first_tokens = tokenizer.tokenize(non_first_chunk)
-            n_non_first_tokens = len(non_first_tokens)
+            batch_uuids.append(sample_uuid)
+            batch_texts.append(non_first_chunk)
+
+        if not batch_texts:
+            return None
+
+        # Parallel batch tokenize across ~95% of cores
+        token_counts = _parallel_batch_tokenize(tokenizer, batch_texts)
+
+        # Pass 2: Compute TPOT using batch-tokenized results
+        rows = []
+        if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
+            repeats = []
+        else:
+            repeats = None

+        for sample_uuid, n_non_first_tokens in zip(
+            batch_uuids, token_counts, strict=False
+        ):
             latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
             if latency is None:
                 raise SampleUUIDNotFoundError(sample_uuid, "events record")
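Both call sites now follow the same two-pass shape: gather (uuid, text) pairs, tokenize once in a single batched call, then zip the counts back to their UUIDs. Condensed, with a hypothetical iter_samples() source standing in for the query loops above:

    # Hypothetical sketch of the gather -> batch-tokenize -> rejoin pattern.
    pairs = [(uuid, text) for uuid, text in iter_samples() if text]         # pass 1: gather
    counts = _parallel_batch_tokenize(tokenizer, [t for _, t in pairs])     # one batched call
    rows = [(uuid, n) for (uuid, _), n in zip(pairs, counts, strict=True)]  # pass 2: rejoin

Because the two lists are built in lockstep, strict=True is safe in this sketch and turns any length mismatch into an immediate error rather than silent truncation.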