diff --git a/src/aci/infrastructure/embedding/client.py b/src/aci/infrastructure/embedding/client.py
index 72f1778..70dedc7 100644
--- a/src/aci/infrastructure/embedding/client.py
+++ b/src/aci/infrastructure/embedding/client.py
@@ -29,8 +29,9 @@ class OpenAIEmbeddingClient(EmbeddingClientInterface):
     errors (HTTP 413) by reducing the batch size and retrying. This allows
     successful embedding generation even when some batches exceed the API's
     token limit. The batch size is halved on each retry until it reaches
-    the configured minimum. If a single item exceeds the limit, a
-    NonRetryableError is raised.
+    the configured minimum. If a single item still exceeds the limit at
+    the minimum batch size, the item is skipped and a zero-vector placeholder
+    is inserted to preserve output ordering.
     """
 
     def __init__(
@@ -115,7 +116,7 @@ async def embed_batch(self, texts: list[str]) -> list[list[float]]:
 
         Raises:
             EmbeddingClientError: If embedding generation fails after retries
-            NonRetryableError: If a single item exceeds token limits
+            NonRetryableError: If embedding generation encounters a non-recoverable error
         """
         if not texts:
             return []
@@ -139,7 +140,7 @@ async def _embed_with_fallback(
             List of embedding vectors in the same order as input texts
 
         Raises:
-            NonRetryableError: If a single item exceeds token limits
+            NonRetryableError: If embedding fails due to non-recoverable API errors
             EmbeddingClientError: If embedding fails after all retries
         """
         all_embeddings: list[list[float]] = []
@@ -162,14 +163,14 @@
 
                 # Check if we can reduce batch size further
                 if current_batch_size <= config.min_batch_size:
-                    # Single item exceeds token limit
-                    logger.error(
+                    # Single item exceeds token limit even at minimum batch size
+                    logger.warning(
                         f"Item at index {i} exceeds token limit, "
-                        f"cannot reduce batch further (min_batch_size={config.min_batch_size})"
+                        f"skipping with zero vector (min_batch_size={config.min_batch_size})"
                     )
-                    raise NonRetryableError(
-                        f"Single item at index {i} exceeds token limit: {e}"
-                    ) from e
+                    all_embeddings.append([0.0] * self._dimension)
+                    i += 1
+                    continue
 
                 # Reduce batch size and retry
                 new_batch_size = max(config.min_batch_size, current_batch_size // 2)
diff --git a/tests/property/test_embedding_client_properties.py b/tests/property/test_embedding_client_properties.py
index 4b94e34..0c60ca1 100644
--- a/tests/property/test_embedding_client_properties.py
+++ b/tests/property/test_embedding_client_properties.py
@@ -428,3 +428,74 @@ async def run_test():
         assert call_count == 1, (
             f"Expected 1 API call when fallback disabled, got {call_count}"
         )
+
+
+@given(texts_count=st.integers(min_value=2, max_value=20))
+@settings(max_examples=50, deadline=None)
+def test_oversized_single_item_is_skipped_with_zero_vector(texts_count: int):
+    """
+    **Feature: embedding-batch-fallback, Property 5: Oversized Item Isolation**
+    **Validates: Requirements 1.3, 4.3**
+
+    *For any* input list containing one permanently oversized item,
+    the client SHALL continue processing remaining items and return a
+    zero-vector placeholder at the oversized item's position.
+    """
+
+    texts = [f"text_{i}" for i in range(texts_count)]
+    oversized_index = texts_count // 2
+
+    async def mock_post(url, headers, json):
+        """Mock HTTP POST that fails only for a specific oversized item."""
+        batch_texts = json.get("input", [])
+
+        # Simulate token limit failure only for the specific oversized item
+        if len(batch_texts) == 1 and batch_texts[0] == texts[oversized_index]:
+            mock_response = MagicMock()
+            mock_response.status_code = 413
+            mock_response.text = "Token limit exceeded"
+            return mock_response
+
+        # Force fallback into single-item processing if oversized item is in a larger batch
+        if texts[oversized_index] in batch_texts:
+            mock_response = MagicMock()
+            mock_response.status_code = 413
+            mock_response.text = "Token limit exceeded"
+            return mock_response
+
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "data": [{"index": i, "embedding": [0.1] * 1536} for i in range(len(batch_texts))]
+        }
+        return mock_response
+
+    client = OpenAIEmbeddingClient(
+        api_url="https://api.example.com/embeddings",
+        api_key="test-key",
+        batch_size=8,
+        retry_config=RetryConfig(
+            max_retries=0,
+            enable_batch_fallback=True,
+            min_batch_size=1,
+        ),
+    )
+
+    async def run_test():
+        with patch("httpx.AsyncClient") as mock_client_class:
+            mock_client = AsyncMock()
+            mock_client.post = mock_post
+            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+            mock_client.__aexit__ = AsyncMock(return_value=None)
+            mock_client_class.return_value = mock_client
+
+            return await client.embed_batch(texts)
+
+    embeddings = asyncio.run(run_test())
+
+    assert len(embeddings) == len(texts), (
+        f"Expected {len(texts)} embeddings, got {len(embeddings)}"
+    )
+    assert embeddings[oversized_index] == [0.0] * 1536, (
+        "Oversized item should be replaced with a zero vector placeholder"
+    )