-
Notifications
You must be signed in to change notification settings - Fork 279
Expand file tree
/
Copy patheval_token_usage.py
More file actions
142 lines (114 loc) · 4.48 KB
/
eval_token_usage.py
File metadata and controls
142 lines (114 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
"""
Evaluate Mirix token usage by uploading the sample images from set1.
This mirrors the add-image logic from run_client.py but focuses solely on
feeding the .local/images/set1 assets into memory so we can inspect usage
statistics returned by the API.
"""
import asyncio
import base64
import logging
import mimetypes
from io import BytesIO
from pathlib import Path
from typing import Iterable, List, Tuple
from PIL import Image, ImageOps
from mirix import MirixClient
# Configure root logging at import time so every record carries a timestamp,
# the emitting logger's name, and the severity level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Longest allowed side, in pixels; larger images are shrunk before encoding.
MAX_IMAGE_DIMENSION = 512
# Directory scanned for the images to upload.
IMAGE_SET_PATH = Path(".local/images/set1")
# Lower-cased file suffixes that are treated as uploadable images.
SUPPORTED_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp"}
def _save_resized_image(image_path: Path) -> Tuple[bytes, str]:
    """Return the image bytes (resized if needed) and its MIME type.

    The file is opened with Pillow, transposed according to its EXIF
    orientation tag, shrunk so that neither side exceeds
    ``MAX_IMAGE_DIMENSION``, and re-encoded into an in-memory buffer.

    Args:
        image_path: Location of the image file on disk.

    Returns:
        A ``(data, mime_type)`` tuple. ``mime_type`` is guessed from the file
        name (``"image/png"`` when it cannot be guessed) and is replaced by
        ``"image/png"`` whenever the PNG fallback encoder had to be used.

    Raises:
        Exceptions raised while opening the file, or by the PNG fallback
        itself, propagate to the caller.
    """
    mime_type = mimetypes.guess_type(image_path.name)[0] or "image/png"
    with Image.open(image_path) as img:
        img = ImageOps.exif_transpose(img)
        if max(img.size) > MAX_IMAGE_DIMENSION:
            img.thumbnail((MAX_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION), Image.Resampling.LANCZOS)

        # Prefer the format Pillow reports; otherwise derive it from the MIME
        # subtype (e.g. "image/jpeg" -> "JPEG").
        save_format = (img.format or mime_type.split("/")[-1]).upper()
        if save_format == "JPG":
            save_format = "JPEG"

        buffer = BytesIO()
        try:
            img.save(buffer, format=save_format)
        except (ValueError, KeyError, OSError):
            # Pillow reports an unsaveable image in several ways: KeyError for
            # a format name with no registered writer, OSError for a mode the
            # encoder cannot write (e.g. RGBA as JPEG), and ValueError in
            # other cases. Catching only ValueError let the first two escape,
            # so the fallback never ran. Start from a clean buffer because a
            # failed save may have written partial data.
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            mime_type = "image/png"
    return buffer.getvalue(), mime_type
def encode_image_to_data_url(image_path: Path) -> str:
    """Build a base64 ``data:`` URL for the image stored at *image_path*.

    The bytes and MIME type come from ``_save_resized_image``, so the embedded
    image may be a resized re-encoding of the file on disk.
    """
    payload, content_type = _save_resized_image(image_path)
    b64_text = base64.b64encode(payload).decode("utf-8")
    return "data:{};base64,{}".format(content_type, b64_text)
def get_images_from_set(set_path: Path) -> List[Path]:
    """Return the sorted image files found directly inside *set_path*.

    Only regular files whose lower-cased suffix is listed in
    ``SUPPORTED_IMAGE_EXTENSIONS`` are kept; subdirectories are not searched.

    Raises:
        FileNotFoundError: If *set_path* does not exist.
        RuntimeError: If the directory holds no supported image.
    """
    if not set_path.exists():
        raise FileNotFoundError(f"Image directory does not exist: {set_path}")

    matches: List[Path] = []
    for entry in set_path.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS:
            matches.append(entry)
    matches.sort()

    if not matches:
        raise RuntimeError(f"No supported images found in {set_path}")

    logger.info("Found %d images in %s", len(matches), set_path)
    return matches
async def add_images_to_memory(client: MirixClient, user_id: str, image_paths: Iterable[Path]) -> None:
    """Upload every image in *image_paths* through ``client.add``.

    Images are processed one at a time, in iteration order. Each call sends a
    two-turn conversation: a user message holding a short text note plus the
    image as a data URL, and an assistant acknowledgement. The ``success`` and
    ``usage`` fields of each response are logged.

    Args:
        client: Client whose ``add`` coroutine receives the messages.
        user_id: Identifier passed through to ``client.add``.
        image_paths: Image files to encode and upload.
    """
    for path in image_paths:
        encoded_image = encode_image_to_data_url(path)
        file_name = path.name
        logger.info("Adding %s ...", file_name)

        # User turn: a text note followed by the inline image payload.
        user_turn = {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Please store this reference image for later: {file_name}",
                },
                {
                    "type": "image_data",
                    "image_data": {"data": encoded_image, "detail": "high"},
                },
            ],
        }
        # Assistant turn: a plain acknowledgement naming the file.
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": f"I've received {file_name}."}],
        }

        result = await client.add(
            user_id=user_id,
            messages=[user_turn, assistant_turn],
            chaining=False,
        )
        logger.info(
            "Response success=%s usage=%s",
            result.get("success"),
            result.get("usage"),
        )
async def main() -> None:
    """Create a Mirix client, initialise its meta agent, and upload the image set.

    The organisation id, user id, and agent config path are fixed values used
    only by this evaluation script.
    """
    mirix = await MirixClient.create(
        org_id="token-usage-org",
        api_key=None,
        debug=True,
    )
    await mirix.initialize_meta_agent(
        config_path="mirix/configs/examples/mirix_gemini.yaml",
        update_agents=False,
    )

    image_files = get_images_from_set(IMAGE_SET_PATH)
    await add_images_to_memory(mirix, user_id="token-usage-user", image_paths=image_files)


# Run the upload only when the file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())