Commit 469355c

add parallel tokenizer

Parent: 076a0db
3 files changed: 78 additions & 17 deletions

src/inference_endpoint/load_generator/session.py

Lines changed: 8 additions & 5 deletions
@@ -93,19 +93,22 @@ def _run_test(
         )
 
         for _ in perf_test_generator:
-            # Actual issue is done during next(generator). Nothing else to do here, just pass.
-            pass
+            if self.stop_requested:
+                self.logger.info(
+                    "Early stop requested, aborting sample issuance"
+                )
+                break
 
         EventRecorder.record_event(
             SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns()
         )
         self.logger.info("All performance samples issued")
 
-        if accuracy_test_generators:
+        if accuracy_test_generators and not self.stop_requested:
             for _, generator in accuracy_test_generators.items():
                 for _ in generator:
-                    # Actual issue is done during next(generator). Nothing else to do here, just pass.
-                    pass
+                    if self.stop_requested:
+                        break
 
             self.logger.info("All accuracy samples issued")

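For context, a minimal runnable sketch of the cooperative-stop pattern the session.py change relies on. Only `stop_requested`, `logger`, and the issue-happens-inside-next(generator) behavior come from the diff; the `Session` skeleton, `request_stop()`, and `samples()` are hypothetical stand-ins for illustration.

import logging
import threading
import time
from collections.abc import Iterator

logging.basicConfig(level=logging.INFO)


class Session:
    """Hypothetical skeleton; only the stop-flag polling mirrors the diff."""

    def __init__(self) -> None:
        self.stop_requested = False  # set externally, polled between samples
        self.logger = logging.getLogger(__name__)

    def request_stop(self) -> None:
        # Assumed entry point (e.g. a signal handler or control thread).
        self.stop_requested = True

    def _run_test(self, perf_test_generator: Iterator[None]) -> None:
        for _ in perf_test_generator:
            # The actual sample issue happens inside next(generator);
            # the loop body only checks the flag between samples.
            if self.stop_requested:
                self.logger.info("Early stop requested, aborting sample issuance")
                break


def samples(n: int) -> Iterator[None]:
    for _ in range(n):
        time.sleep(0.01)  # stand-in for issuing one sample
        yield


session = Session()
threading.Timer(0.05, session.request_stop).start()  # stop after ~5 samples
session._run_test(samples(1000))

Because the flag is only polled between iterations, a stop request never interrupts an in-flight sample; at worst one more sample completes before the loop breaks, which is also why the accuracy phase re-checks `stop_requested` before starting.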
src/inference_endpoint/metrics/reporter.py

Lines changed: 65 additions & 12 deletions
@@ -39,6 +39,40 @@
 if TYPE_CHECKING:
     from transformers import Tokenizer
 
+logger = logging.getLogger(__name__)
+
+
+def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]:
+    """Batch-tokenize texts using all available cores and return token counts.
+
+    Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores.
+    HuggingFace tokenizers use a Rust backend that releases the GIL,
+    so threads achieve real parallelism without GIL contention.
+    """
+    from concurrent.futures import ThreadPoolExecutor
+
+    n_cores = os.cpu_count() or 1
+    n_workers = max(1, int(n_cores * 0.95))
+
+    if len(texts) <= n_workers:
+        # Few texts: just tokenize directly, no threading overhead
+        encoded = tokenizer(texts, add_special_tokens=False)
+        return [len(ids) for ids in encoded["input_ids"]]
+
+    # Split texts into chunks, one per worker
+    chunk_size = (len(texts) + n_workers - 1) // n_workers
+    chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]
+
+    def _tokenize_chunk(chunk: list[str]) -> list[int]:
+        encoded = tokenizer(chunk, add_special_tokens=False)
+        return [len(ids) for ids in encoded["input_ids"]]
+
+    results: list[int] = []
+    with ThreadPoolExecutor(max_workers=n_workers) as pool:
+        for chunk_lengths in pool.map(_tokenize_chunk, chunks):
+            results.extend(chunk_lengths)
+    return results
+
 
 class TPOTReportingMode(str, Enum):
     """TPOT (Time Per Output Token) reporting mode.
@@ -1034,7 +1068,9 @@ def get_output_sequence_lengths(
         """
         query_result = self.get_sample_outputs()
 
-        rows = []
+        # Collect all texts for batch tokenization
+        uuids: list[str] = []
+        texts: list[str] = []
         for sample_uuid, data_bytes in query_result:
             output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes)
 
@@ -1047,13 +1083,16 @@ def get_output_sequence_lengths(
             else:
                 full_sequence = output_sequence
 
-            # Tokenize and calculate length
-            output_tokens = tokenizer.tokenize(full_sequence)
-            rows.append((sample_uuid, len(output_tokens)))
+            uuids.append(sample_uuid)
+            texts.append(full_sequence)
 
-        if not rows:
+        if not texts:
             return None
 
+        # Parallel batch tokenize across ~95% of cores
+        token_counts = _parallel_batch_tokenize(tokenizer, texts)
+        rows = list(zip(uuids, token_counts, strict=False))
+
         return RollupQueryTable("output_sequence_length", None, rows)
 
     @profile
@@ -1103,11 +1142,9 @@ def derive_TPOT(
         if not query_result:
             return None
 
-        rows = []
-        if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
-            repeats = []
-        else:
-            repeats = None
+        # Pass 1: Collect all non-first-chunk texts for batch tokenization
+        batch_uuids: list[str] = []
+        batch_texts: list[str] = []
 
         for sample_uuid, data_bytes in query_result:
             if data_bytes is None or len(data_bytes) == 0:
@@ -1155,9 +1192,25 @@
                 # Possible malformed output data where empty string is included as a non-first chunk
                 continue
 
-            non_first_tokens = tokenizer.tokenize(non_first_chunk)
-            n_non_first_tokens = len(non_first_tokens)
+            batch_uuids.append(sample_uuid)
+            batch_texts.append(non_first_chunk)
+
+        if not batch_texts:
+            return None
+
+        # Parallel batch tokenize across ~95% of cores
+        token_counts = _parallel_batch_tokenize(tokenizer, batch_texts)
+
+        # Pass 2: Compute TPOT using batch-tokenized results
+        rows = []
+        if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
+            repeats = []
+        else:
+            repeats = None
 
+        for sample_uuid, n_non_first_tokens in zip(
+            batch_uuids, token_counts, strict=False
+        ):
             latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
             if latency is None:
                 raise SampleUUIDNotFoundError(sample_uuid, "events record")

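For reference, a sketch of how the new helper could be driven outside the reporter. The gpt2 checkpoint, the `use_fast=True` flag, and the import path are illustrative assumptions, not part of the commit; any Rust-backed ("fast") HuggingFace tokenizer fits, since the Rust encode path releases the GIL and lets the thread pool run in true parallel.

from transformers import AutoTokenizer

from inference_endpoint.metrics.reporter import _parallel_batch_tokenize

# Any fast tokenizer works; "gpt2" is just a small, public example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

texts = ["hello world"] * 10_000
token_counts = _parallel_batch_tokenize(tokenizer, texts)

assert len(token_counts) == len(texts)
print(token_counts[0])  # e.g. 2 tokens for "hello world" under GPT-2's BPE

Ordering is what makes the zip-back in get_output_sequence_lengths and derive_TPOT safe: the chunks are contiguous slices and ThreadPoolExecutor.map returns results in submission order, so the flattened counts line up index-for-index with the collected UUIDs. The worker math is ceil division: on 16 cores, n_workers = int(16 * 0.95) = 15, and 10,000 texts split into chunks of (10000 + 14) // 15 = 667.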
tests/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -308,6 +308,11 @@ class CharacterTokenizer:
     def tokenize(self, text: str) -> list[str]:
         return list(text)
 
+    def __call__(
+        self, texts: list[str], **kwargs: object
+    ) -> dict[str, list[list[int]]]:
+        return {"input_ids": [list(range(len(t))) for t in texts]}
+
 
 @pytest.fixture
 def tokenizer():

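The new `__call__` gives the CharacterTokenizer test double the batch interface that `_parallel_batch_tokenize` expects: call it with a list of texts and read per-text ids from "input_ids". A small illustrative check, with the class mirrored here so the snippet stands alone:

class CharacterTokenizer:
    # Mirror of the test double in tests/conftest.py (see diff above).
    def tokenize(self, text: str) -> list[str]:
        return list(text)

    def __call__(
        self, texts: list[str], **kwargs: object
    ) -> dict[str, list[list[int]]]:
        return {"input_ids": [list(range(len(t))) for t in texts]}


tok = CharacterTokenizer()
batch = tok(["abc", "hello"], add_special_tokens=False)

# One "token" per character, so batch counts agree with tokenize().
assert [len(ids) for ids in batch["input_ids"]] == [3, 5]
assert len(tok.tokenize("hello")) == 5

Keeping `__call__` consistent with `tokenize` means the existing single-text tests and the new batch path measure the same notion of token length.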