Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ packages = ["tinfoil", "tinfoil.attestation"]

[project]
name = "tinfoil"
version = "0.4.1"
version = "0.10.0"
description = "Python client for Tinfoil"
readme = "README.md"
requires-python = ">=3.10"
Expand All @@ -20,7 +20,7 @@ dependencies = [
"requests>=2.31.0",
"cryptography>=42.0.0",
"pyOpenSSL>=25.0.0",
"sigstore>=3.6.2",
"sigstore>=4.1.0",
"platformdirs>=4.2.0",
"pytest-asyncio>=0.26.0"
]
Expand Down
25 changes: 16 additions & 9 deletions src/tinfoil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import hashlib
import ssl
import cryptography.x509
from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding

from openai import OpenAI, AsyncOpenAI
from openai.resources.chat import Chat as OpenAIChat
from openai.resources.embeddings import Embeddings as OpenAIEmbeddings
from openai.resources.audio import Audio as OpenAIAudio

from .client import SecureClient
from .client import SecureClient, get_router_address

class TinfoilAI:
chat: OpenAIChat
Expand All @@ -17,14 +12,18 @@ class TinfoilAI:
api_key: str
enclave: str

def __init__(self, enclave: str = "inference.tinfoil.sh", repo: str = "tinfoilsh/confidential-inference-proxy", api_key: str = "tinfoil", measurement: dict = None):
def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: dict = None):
if measurement is not None:
repo = ""

# Ensure at least one verification method is provided
if measurement is None and (repo == "" or repo is None):
raise ValueError("Must provide either 'measurement' or 'repo' parameter for verification.")

# If enclave is empty, fetch a random one from the routers API
if enclave == "" or enclave is None:
enclave = get_router_address()

self.enclave = enclave
self.api_key = api_key
tf_client = SecureClient(enclave, repo, measurement)
Expand All @@ -48,14 +47,18 @@ class AsyncTinfoilAI:
api_key: str
enclave: str

def __init__(self, enclave: str = "inference.tinfoil.sh", repo: str = "tinfoilsh/confidential-inference-proxy", api_key: str = "tinfoil", measurement: dict = None):
def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: dict = None):
if measurement is not None:
repo = ""

# Ensure at least one verification method is provided
if measurement is None and (repo == "" or repo is None):
raise ValueError("Must provide either 'measurement' or 'repo' parameter for verification.")

# If enclave is empty, fetch a random one from the routers API
if enclave == "" or enclave is None:
enclave = get_router_address()

self.enclave = enclave
self.api_key = api_key
# verifier client remains sync; only used to fetch the expected public key
Expand Down Expand Up @@ -92,7 +95,7 @@ def post(
return self._http_client.post(url, headers=headers, data=data, json=json, timeout=timeout)


def NewSecureClient(enclave: str = "inference.tinfoil.sh", repo: str = "tinfoilsh/confidential-inference-proxy", api_key: str = "tinfoil", measurement: dict = None):
def NewSecureClient(enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: dict = None):
"""
Create a secure HTTP client for direct GET/POST through the Tinfoil enclave.
"""
Expand All @@ -103,6 +106,10 @@ def NewSecureClient(enclave: str = "inference.tinfoil.sh", repo: str = "tinfoils
if measurement is None and (repo == "" or repo is None):
raise ValueError("Must provide either 'measurement' or 'repo' parameter for verification.")

# If enclave is empty, fetch a random one from the routers API
if enclave == "" or enclave is None:
enclave = get_router_address()

tf_client = SecureClient(enclave, repo, measurement)
return _HTTPSecureClient(enclave, tf_client, api_key)

Expand Down
22 changes: 21 additions & 1 deletion src/tinfoil/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ssl
import urllib.request
import httpx
import random
from dataclasses import dataclass
from typing import Dict, Optional
from urllib.parse import urlparse
Expand Down Expand Up @@ -70,7 +71,7 @@ def _get_connection(self, host, timeout=None):
class SecureClient:
"""A client that verifies and communicates with secure enclaves"""

def __init__(self, enclave: str = "inference.tinfoil.sh", repo: str = "tinfoilsh/confidential-inference-proxy", measurement: Optional[dict] = None):
def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", measurement: Optional[dict] = None):
# Hardcoded measurement takes precedence over repo
if measurement is not None:
repo = ""
Expand All @@ -79,6 +80,10 @@ def __init__(self, enclave: str = "inference.tinfoil.sh", repo: str = "tinfoilsh
if measurement is None and (repo == "" or repo is None):
raise ValueError("Must provide either 'measurement' or 'repo' parameter for verification.")

# If enclave is empty, fetch a random one from the routers API
if enclave == "" or enclave is None:
enclave = get_router_address()

self.enclave = enclave
self.repo = repo
self.measurement = measurement
Expand Down Expand Up @@ -225,3 +230,18 @@ def get(self, url: str, headers: Dict[str, str] = {}) -> Response:
method="GET"
)
return self.make_request(req)

def get_router_address() -> str:
"""
Fetches the list of available routers from the ATC API
and returns a randomly selected address.
"""

routers_url = "https://atc.tinfoil.sh/routers?platform=snp"

with urllib.request.urlopen(routers_url) as response:
routers = json.loads(response.read().decode('utf-8'))
if len(routers) == 0:
raise ValueError("No routers found in the response")

return random.choice(routers)
8 changes: 2 additions & 6 deletions src/tinfoil/sigstore.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import sigstore
from sigstore.verify import Verifier
from sigstore.verify.policy import AllOf, OIDCIssuer, GitHubWorkflowRepository, GitHubWorkflowRef, GitHubWorkflowSHA, Certificate, _OIDC_GITHUB_WORKFLOW_REF_OID, ExtensionNotFound, OIDCSourceRepositoryDigest, OIDCBuildSignerDigest, OIDCBuildConfigDigest
from sigstore.verify.policy import AllOf, OIDCIssuer, GitHubWorkflowRepository, Certificate, _OIDC_GITHUB_WORKFLOW_REF_OID, ExtensionNotFound
from sigstore.models import Bundle
from sigstore.errors import VerificationError
from sigstore._utils import sha256_digest
import json
import re
import binascii

from .attestation import Measurement, PredicateType

Expand Down Expand Up @@ -97,12 +94,11 @@ def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measuremen
registers = [predicate_fields["snp_measurement"]]
else:
raise ValueError(f"Unsupported predicate type: {predicate_type}")

return Measurement(
type=predicate_type,
registers=registers
)

except Exception as e:
raise ValueError(f"Attestation processing failed: {e}") from e

17 changes: 11 additions & 6 deletions tests/test_attestation_flow.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import os
import pytest
import sys

# Adjust these imports based on your project structure
from tinfoil.github import fetch_latest_digest, fetch_attestation_bundle
from tinfoil.sigstore import verify_attestation
from tinfoil.attestation import fetch_attestation
from tinfoil.client import get_router_address

pytestmark = pytest.mark.integration # allows pytest -m integration filtering

# Fetch config from environment variables, falling back to defaults
# Use the same env vars as the other integration test for consistency
ENCLAVE = "inference.tinfoil.sh"
REPO = "tinfoilsh/confidential-inference-proxy"
REPO = "tinfoilsh/confidential-model-router"

def test_full_verification_flow():
"""
Expand All @@ -25,6 +23,13 @@ def test_full_verification_flow():
6. Compare code measurements with runtime measurements.
"""
try:
# Fetch enclave address lazily inside the test to avoid import-time network calls
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")
return

# Fetch latest release digest
print(f"Fetching latest release for {REPO}")
digest = fetch_latest_digest(REPO)
Expand All @@ -47,8 +52,8 @@ def test_full_verification_flow():


# Fetch runtime attestation from the enclave
print(f"Fetching runtime attestation from {ENCLAVE}")
enclave_attestation = fetch_attestation(ENCLAVE)
print(f"Fetching runtime attestation from {enclave}")
enclave_attestation = fetch_attestation(enclave)
assert enclave_attestation is not None # Basic check

# Verify enclave measurements from runtime attestation
Expand Down
10 changes: 6 additions & 4 deletions tests/test_integration_async.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# tests/test_integration_async.py
import os
import pytest
from tinfoil import AsyncTinfoilAI
from tinfoil import AsyncTinfoilAI, get_router_address

pytestmark = pytest.mark.integration

ENCLAVE = "inference.tinfoil.sh"
REPO = "tinfoilsh/confidential-inference-proxy"
REPO = "tinfoilsh/confidential-model-router"
API_KEY = os.getenv("TINFOIL_API_KEY", "tinfoil")

@pytest.mark.asyncio
async def test_async_chat_integration():
enclave = get_router_address()
assert enclave is not None

client = AsyncTinfoilAI(
enclave=ENCLAVE,
enclave=enclave,
repo=REPO,
api_key=API_KEY,
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_integration_http.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
import pytest
from tinfoil import NewSecureClient
from tinfoil import NewSecureClient, get_router_address

ENCLAVE = "inference.tinfoil.sh"
REPO = "tinfoilsh/confidential-inference-proxy"
REPO = "tinfoilsh/confidential-model-router"
API_KEY = os.getenv("TINFOIL_API_KEY", "tinfoil")

pytestmark = pytest.mark.integration

def test_http_integration():
client = NewSecureClient(ENCLAVE, REPO, api_key=API_KEY)
enclave = get_router_address()
client = NewSecureClient(enclave, REPO, api_key=API_KEY)

url = f"https://{ENCLAVE}/v1/chat/completions"
url = f"https://{enclave}/v1/chat/completions"
headers = {"Authorization": f"Bearer {API_KEY}"}
payload = {
"model": "llama3-3-70b",
Expand Down