Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ jobs:
name: Build and Publish
runs-on: ubuntu-latest
permissions:
id-token: write
contents: read
steps:
- name: Checkout
Expand Down Expand Up @@ -49,11 +48,23 @@ jobs:
print(f"Version mismatch: pyproject={project_version} tag={tag_version}")
sys.exit(1)

# Also verify version.py is in sync
version_file = Path("src/engram/version.py")
namespace: dict = {}
exec(version_file.read_text(), namespace)
module_version = namespace["__version__"]

if module_version != project_version:
print(f"Version mismatch: version.py={module_version} pyproject={project_version}")
sys.exit(1)

print(f"Version check passed: {project_version}")
PY

- name: Build package
run: uv build --no-sources

- name: Publish to PyPI (OIDC)
uses: pypa/gh-action-pypi-publish@release/v1
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
run: uv publish
31 changes: 29 additions & 2 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,41 @@
from ._models import (
CommittedOperation,
CommittedOperations,
Memory,
PreExtractedContent,
RetrievalConfig,
Run,
RunStatus,
SearchResults,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
from .errors import APIError, AuthError, EngramError, ValidationError
from .errors import (
APIError,
AuthenticationError,
ConnectionError,
EngramError,
EngramTimeoutError,
ValidationError,
)
from .version import __version__

__all__ = [
"APIError",
"AsyncEngramClient",
"AuthError",
"AuthenticationError",
"CommittedOperation",
"CommittedOperations",
"ConnectionError",
"EngramClient",
"EngramError",
"EngramTimeoutError",
"Memory",
"PreExtractedContent",
"RetrievalConfig",
"Run",
"RunStatus",
"SearchResults",
"ValidationError",
"__version__",
]
46 changes: 5 additions & 41 deletions src/engram/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

import httpx

from .errors import ValidationError
from .types import ClientConfig
Expand All @@ -14,25 +11,21 @@


class _BaseClient:
"""Shared client behavior for sync and async clients."""

_http_client: httpx.Client | httpx.AsyncClient
_owns_http_client: bool
"""Shared config and header logic for sync and async clients."""

def __init__(
self,
*,
base_url: str = DEFAULT_BASE_URL,
api_key: str | None = None,
api_key: str,
headers: Mapping[str, str] | None = None,
timeout: float = DEFAULT_TIMEOUT,
) -> None:
if timeout <= 0:
raise ValidationError("Timeout must be greater than 0.")

normalized_base_url = base_url.rstrip("/")
header_overrides = headers if headers is not None else {}
default_headers = _build_headers(api_key=api_key, header_overrides=header_overrides)
default_headers = _build_headers(api_key=api_key, header_overrides=headers or {})

self._config = ClientConfig(
base_url=normalized_base_url,
Expand All @@ -49,46 +42,17 @@ def config(self) -> ClientConfig:
def default_headers(self) -> dict[str, str]:
return dict(self._config.headers)

def build_request(
self,
method: str,
path: str,
*,
headers: Mapping[str, str] | None = None,
params: Mapping[str, Any] | None = None,
json: Any | None = None,
) -> httpx.Request:
merged_headers = self.default_headers
if headers:
merged_headers.update(headers)

return self._http_client.build_request(
method=method,
url=_build_url(self._config.base_url, path),
headers=merged_headers,
params=params,
json=json,
)


def _build_headers(
*,
api_key: str | None,
api_key: str,
header_overrides: Mapping[str, str],
) -> dict[str, str]:
headers: dict[str, str] = {
"Accept": "application/json",
"Content-Type": "application/json",
"User-Agent": f"weaviate-engram/{__version__}",
"Authorization": f"Bearer {api_key}",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
headers.update(header_overrides)
return headers


def _build_url(base_url: str, path: str) -> str:
clean_path = path.lstrip("/")
if not clean_path:
return base_url
return f"{base_url}/{clean_path}"
145 changes: 145 additions & 0 deletions src/engram/_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

import httpx

from .errors import APIError, AuthenticationError
from .errors import ConnectionError as EngramConnectionError
from .types import ClientConfig


class HttpTransport:
"""Wraps a sync httpx.Client and handles request building and response processing."""

def __init__(self, config: ClientConfig, http_client: httpx.Client | None = None) -> None:
self._config = config
self._owns_http_client = http_client is None
self._http_client = http_client or httpx.Client(timeout=config.timeout)

def close(self) -> None:
if self._owns_http_client:
self._http_client.close()

def request(
self,
method: str,
path: str,
*,
params: Mapping[str, Any] | None = None,
json: Any | None = None,
) -> dict[str, Any]:
req = self.build_request(method, path, params=params, json=json)
try:
response = self._http_client.send(req)
except httpx.ConnectError as exc:
raise EngramConnectionError(str(exc)) from exc
return _process_response(response)

def build_request(
self,
method: str,
path: str,
*,
headers: Mapping[str, str] | None = None,
params: Mapping[str, Any] | None = None,
json: Any | None = None,
) -> httpx.Request:
merged_headers = dict(self._config.headers)
if headers:
merged_headers.update(headers)
clean_path = path.lstrip("/")
url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url
return self._http_client.build_request(
method=method,
url=url,
headers=merged_headers,
params=params,
json=json,
)


class AsyncHttpTransport:
"""Wraps an async httpx.AsyncClient and handles request building and response processing."""

def __init__(self, config: ClientConfig, http_client: httpx.AsyncClient | None = None) -> None:
self._config = config
self._owns_http_client = http_client is None
self._http_client = http_client or httpx.AsyncClient(timeout=config.timeout)

async def close(self) -> None:
if self._owns_http_client:
await self._http_client.aclose()

async def request(
self,
method: str,
path: str,
*,
params: Mapping[str, Any] | None = None,
json: Any | None = None,
) -> dict[str, Any]:
req = self.build_request(method, path, params=params, json=json)
try:
response = await self._http_client.send(req)
except httpx.ConnectError as exc:
raise EngramConnectionError(str(exc)) from exc
return _process_response(response)

def build_request(
self,
method: str,
path: str,
*,
headers: Mapping[str, str] | None = None,
params: Mapping[str, Any] | None = None,
json: Any | None = None,
) -> httpx.Request:
merged_headers = dict(self._config.headers)
if headers:
merged_headers.update(headers)
clean_path = path.lstrip("/")
url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url
return self._http_client.build_request(
method=method,
url=url,
headers=merged_headers,
params=params,
json=json,
)


def _process_response(response: httpx.Response) -> dict[str, Any]:
data = _safe_json(response)

if response.status_code == 401:
detail = _extract_detail(data, "Authentication failed")
raise AuthenticationError(detail, body=data)

if response.status_code >= 400:
detail = _extract_detail(data, response.reason_phrase)
raise APIError(detail, status_code=response.status_code, body=data)

if data is None:
return {}
if isinstance(data, dict):
return data
return {"data": data}


def _extract_detail(data: Any, fallback: str) -> str:
if isinstance(data, dict):
return str(data.get("detail", fallback))
if data:
return str(data)
return fallback


def _safe_json(response: httpx.Response) -> Any:
if not response.content:
return None
try:
return response.json()
except Exception:
return None
14 changes: 14 additions & 0 deletions src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .memory import AddContent, Memory, PreExtractedContent, RetrievalConfig, SearchResults
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

__all__ = [
"AddContent",
"CommittedOperation",
"CommittedOperations",
"Memory",
"PreExtractedContent",
"RetrievalConfig",
"Run",
"RunStatus",
"SearchResults",
]
58 changes: 58 additions & 0 deletions src/engram/_models/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from typing import Literal, TypeAlias


@dataclass(slots=True)
class PreExtractedContent:
"""Pre-extracted content that bypasses the extraction pipeline."""

content: str
tags: list[str] = field(default_factory=list)


# Type alias for the content argument to memories.add()
AddContent: TypeAlias = str | list[dict[str, str]] | PreExtractedContent


@dataclass(slots=True)
class RetrievalConfig:
retrieval_type: Literal["vector", "bm25", "hybrid"] = "hybrid"
limit: int = 10


@dataclass(slots=True)
class Memory:
id: str
project_id: str
content: str
topic: str
group: str
created_at: str
updated_at: str
user_id: str | None = None
conversation_id: str | None = None
tags: list[str] | None = None
score: float | None = None


class SearchResults(Sequence[Memory]):
"""List-like wrapper over search results with a total count."""

def __init__(self, memories: list[Memory], total: int) -> None:
self._memories = memories
self.total = total

def __getitem__(self, index: int) -> Memory: # type: ignore[override]
return self._memories[index]

def __len__(self) -> int:
return len(self._memories)

def __iter__(self) -> Iterator[Memory]:
return iter(self._memories)

def __repr__(self) -> str:
return f"SearchResults(total={self.total}, returned={len(self._memories)})"
Loading