Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/aci/infrastructure/embedding/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class OpenAIEmbeddingClient(EmbeddingClientInterface):
errors (HTTP 413) by reducing the batch size and retrying. This allows
successful embedding generation even when some batches exceed the API's
token limit. The batch size is halved on each retry until it reaches
the configured minimum. If a single item exceeds the limit, a
NonRetryableError is raised.
the configured minimum. If a single item still exceeds the limit at
minimum batch size, the item is skipped and a zero vector placeholder
is inserted to preserve output ordering.
"""

def __init__(
Expand Down Expand Up @@ -115,7 +116,7 @@ async def embed_batch(self, texts: list[str]) -> list[list[float]]:

Raises:
EmbeddingClientError: If embedding generation fails after retries
NonRetryableError: If a single item exceeds token limits
NonRetryableError: If embedding generation encounters a non-recoverable error
"""
if not texts:
return []
Expand All @@ -139,7 +140,7 @@ async def _embed_with_fallback(
List of embedding vectors in the same order as input texts

Raises:
NonRetryableError: If a single item exceeds token limits
NonRetryableError: If embedding fails due to non-recoverable API errors
EmbeddingClientError: If embedding fails after all retries
"""
all_embeddings: list[list[float]] = []
Expand All @@ -162,14 +163,14 @@ async def _embed_with_fallback(

# Check if we can reduce batch size further
if current_batch_size <= config.min_batch_size:
# Single item exceeds token limit
logger.error(
# Single item exceeds token limit even at minimum batch size
logger.warning(
f"Item at index {i} exceeds token limit, "
f"cannot reduce batch further (min_batch_size={config.min_batch_size})"
f"skipping with zero vector (min_batch_size={config.min_batch_size})"
)
raise NonRetryableError(
f"Single item at index {i} exceeds token limit: {e}"
) from e
all_embeddings.append([0.0] * self._dimension)
i += 1
continue

# Reduce batch size and retry
new_batch_size = max(config.min_batch_size, current_batch_size // 2)
Expand Down
71 changes: 71 additions & 0 deletions tests/property/test_embedding_client_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,74 @@ async def run_test():
assert call_count == 1, (
f"Expected 1 API call when fallback disabled, got {call_count}"
)


@given(texts_count=st.integers(min_value=2, max_value=20))
@settings(max_examples=50, deadline=None)
def test_oversized_single_item_is_skipped_with_zero_vector(texts_count: int):
    """
    **Feature: embedding-batch-fallback, Property 5: Oversized Item Isolation**
    **Validates: Requirements 1.3, 4.3**

    *For any* input list containing one permanently oversized item,
    the client SHALL continue processing remaining items and return a
    zero-vector placeholder at the oversized item's position, while all
    other items receive real (non-placeholder) embeddings.
    """
    texts = [f"text_{i}" for i in range(texts_count)]
    oversized_index = texts_count // 2

    async def mock_post(url, headers, json):
        """Mock HTTP POST: any batch containing the oversized item gets 413.

        A single membership check covers both cases the fallback must handle:
        the oversized item inside a larger batch (forcing the client to halve
        the batch repeatedly) and the oversized item alone at min_batch_size
        (forcing the zero-vector skip). A separate single-item branch would be
        redundant, since `in` also matches one-element batches.
        """
        batch_texts = json.get("input", [])

        if texts[oversized_index] in batch_texts:
            mock_response = MagicMock()
            mock_response.status_code = 413
            mock_response.text = "Token limit exceeded"
            return mock_response

        # Batches without the oversized item succeed with a constant non-zero
        # embedding, so successful items are distinguishable from the
        # zero-vector placeholder in the assertions below.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "data": [
                {"index": i, "embedding": [0.1] * 1536}
                for i in range(len(batch_texts))
            ]
        }
        return mock_response

    client = OpenAIEmbeddingClient(
        api_url="https://api.example.com/embeddings",
        api_key="test-key",
        batch_size=8,
        retry_config=RetryConfig(
            max_retries=0,
            enable_batch_fallback=True,
            min_batch_size=1,
        ),
    )

    async def run_test():
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client.post = mock_post
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock(return_value=None)
            mock_client_class.return_value = mock_client

            return await client.embed_batch(texts)

    embeddings = asyncio.run(run_test())

    assert len(embeddings) == len(texts), (
        f"Expected {len(texts)} embeddings, got {len(embeddings)}"
    )
    assert embeddings[oversized_index] == [0.0] * 1536, (
        "Oversized item should be replaced with a zero vector placeholder"
    )
    # Requirement 1.3: remaining items must still be processed — each
    # non-oversized position must carry the mock's real embedding, proving
    # the oversized item was isolated rather than poisoning its batch-mates.
    for idx, emb in enumerate(embeddings):
        if idx != oversized_index:
            assert emb == [0.1] * 1536, (
                f"Item at index {idx} should have a real embedding, "
                f"not a placeholder"
            )