Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,7 @@ jobs:
pip install "uvloop==0.21.0"
pip install "fastuuid==0.12.0"
pip install jsonschema
pip install "orjson==3.10.12"
- setup_litellm_enterprise_pip
- run:
name: Run tests
Expand Down
4 changes: 2 additions & 2 deletions litellm/litellm_core_utils/safe_json_dumps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import orjson
from typing import Any, Union

from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
Expand Down Expand Up @@ -49,4 +49,4 @@ def _serialize(obj: Any, seen: set, depth: int) -> Any:
return "Unserializable Object"

safe_data = _serialize(data, set(), 0)
return json.dumps(safe_data, default=str)
return orjson.dumps(safe_data, default=str).decode("utf-8")
4 changes: 2 additions & 2 deletions litellm/litellm_core_utils/safe_json_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
Helper for safe JSON loading in LiteLLM.
"""
from typing import Any
import json
import orjson

def safe_json_loads(data: str, default: Any = None) -> Any:
    """
    Safely parse a JSON string.

    Parses with orjson for speed. Any failure (malformed JSON, wrong input
    type, etc.) is swallowed deliberately and ``default`` is returned
    (``None`` unless the caller supplies something else) instead of raising.
    """
    try:
        # orjson.loads raises orjson.JSONDecodeError (a ValueError subclass)
        # on bad input; the broad except keeps this helper best-effort.
        return orjson.loads(data)
    except Exception:
        return default
11 changes: 6 additions & 5 deletions litellm/litellm_core_utils/streaming_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections.abc
import datetime
import json
import orjson
import threading
import time
import traceback
Expand Down Expand Up @@ -267,7 +268,7 @@ def handle_predibase_chunk(self, chunk):
finish_reason = ""
print_verbose(f"chunk: {chunk}")
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
data_json = orjson.loads(chunk[5:])
print_verbose(f"data json: {data_json}")
if "token" in data_json and "text" in data_json["token"]:
text = data_json["token"]["text"]
Expand Down Expand Up @@ -301,7 +302,7 @@ def handle_predibase_chunk(self, chunk):

def handle_ai21_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
data_json = orjson.loads(chunk)
try:
text = data_json["completions"][0]["data"]["text"]
is_finished = True
Expand All @@ -316,7 +317,7 @@ def handle_ai21_chunk(self, chunk): # fake streaming

def handle_maritalk_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
data_json = orjson.loads(chunk)
try:
text = data_json["answer"]
is_finished = True
Expand All @@ -337,7 +338,7 @@ def handle_nlp_cloud_chunk(self, chunk):
if self.model and "dolphin" in self.model:
chunk = self.process_chunk(chunk=chunk)
else:
data_json = json.loads(chunk)
data_json = orjson.loads(chunk)
chunk = data_json["generated_text"]
text = chunk
if "[DONE]" in text:
Expand All @@ -354,7 +355,7 @@ def handle_nlp_cloud_chunk(self, chunk):

def handle_aleph_alpha_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
data_json = orjson.loads(chunk)
try:
text = data_json["completions"][0]["completion"]
is_finished = True
Expand Down
21 changes: 11 additions & 10 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import orjson
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -133,7 +134,7 @@ async def _make_common_async_call(
data=(
signed_json_body
if signed_json_body is not None
else json.dumps(data)
else orjson.dumps(data)
),
timeout=timeout,
stream=stream,
Expand Down Expand Up @@ -193,7 +194,7 @@ def _make_common_sync_call(
data=(
signed_json_body
if signed_json_body is not None
else json.dumps(data)
else orjson.dumps(data)
),
timeout=timeout,
stream=stream,
Expand Down Expand Up @@ -831,7 +832,7 @@ def embedding(
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
data=orjson.dumps(data),
timeout=timeout,
)
except Exception as e:
Expand Down Expand Up @@ -963,7 +964,7 @@ def rerank(
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
data=orjson.dumps(data),
timeout=timeout,
)
except Exception as e:
Expand Down Expand Up @@ -1005,7 +1006,7 @@ async def arerank(
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
data=orjson.dumps(request_data),
timeout=timeout,
)
except Exception as e:
Expand Down Expand Up @@ -1852,7 +1853,7 @@ async def async_anthropic_messages_handler(
response = await async_httpx_client.post(
url=request_url,
headers=headers,
data=signed_json_body or json.dumps(request_body),
data=signed_json_body or orjson.dumps(request_body),
stream=stream or False,
logging_obj=logging_obj,
)
Expand Down Expand Up @@ -2756,7 +2757,7 @@ def create_file(
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
data=orjson.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)

Expand Down Expand Up @@ -2867,7 +2868,7 @@ async def async_create_file(
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
data=orjson.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)

Expand Down Expand Up @@ -4970,7 +4971,7 @@ async def async_vector_store_search_handler(
)

request_data = (
json.dumps(request_body) if signed_json_body is None else signed_json_body
orjson.dumps(request_body) if signed_json_body is None else signed_json_body
)

try:
Expand Down Expand Up @@ -5073,7 +5074,7 @@ def vector_store_search_handler(
)

request_data = (
json.dumps(request_body) if signed_json_body is None else signed_json_body
orjson.dumps(request_body) if signed_json_body is None else signed_json_body
)

try:
Expand Down
2 changes: 1 addition & 1 deletion litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
import traceback
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -48,6 +47,7 @@
ProxyConfig = Any
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.types.utils import ModelResponse, ModelResponseStream, Usage
import datetime


async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional[int]:
Expand Down
9 changes: 5 additions & 4 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import io
import orjson
import os
import random
import secrets
Expand Down Expand Up @@ -3897,7 +3898,7 @@ async def async_assistants_data_generator(

# chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk: # type: ignore
c = c.model_dump_json(exclude_none=True)
c = orjson.dumps(c.model_dump(exclude_none=True)).decode("utf-8")
try:
yield f"data: {c}\n\n"
except Exception as e:
Expand Down Expand Up @@ -3932,7 +3933,7 @@ async def async_assistants_data_generator(
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
error_returned = orjson.dumps({"error": proxy_exception.to_dict()}).decode("utf-8")
yield f"data: {error_returned}\n\n"


Expand Down Expand Up @@ -3965,7 +3966,7 @@ async def async_data_generator(
str_so_far += response_str

if isinstance(chunk, BaseModel):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
chunk = orjson.dumps(chunk.model_dump(exclude_none=True, exclude_unset=True)).decode("utf-8")
elif isinstance(chunk, str) and chunk.startswith("data: "):
error_message = chunk
break
Expand Down Expand Up @@ -4009,7 +4010,7 @@ async def async_data_generator(
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
error_returned = orjson.dumps({"error": proxy_exception.to_dict()}).decode("utf-8")
yield f"data: {error_returned}\n\n"


Expand Down
7 changes: 4 additions & 3 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import hashlib
import json
import orjson
import os
import smtplib
import threading
Expand Down Expand Up @@ -1728,7 +1729,7 @@ def get_request_status(
)
if isinstance(payload_metadata, str):
payload_metadata_json: Union[Dict, SpendLogsMetadata] = cast(
Dict, json.loads(payload_metadata)
Dict, orjson.loads(payload_metadata)
)
else:
payload_metadata_json = payload_metadata
Expand All @@ -1740,7 +1741,7 @@ def get_request_status(
else "success"
)

except (json.JSONDecodeError, AttributeError):
except (orjson.JSONDecodeError, AttributeError):
# Default to success if metadata parsing fails
return "success"

Expand All @@ -1756,7 +1757,7 @@ def jsonify_object(self, data: dict) -> dict:
for k, v in db_data.items():
if isinstance(v, dict):
try:
db_data[k] = json.dumps(v)
db_data[k] = orjson.dumps(v).decode("utf-8")
except Exception:
# This avoids Prisma retrying this 5 times, and making 5 clients
db_data[k] = "failed-to-serialize-json"
Expand Down
10 changes: 6 additions & 4 deletions litellm/router_utils/cooldown_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Wrapper around router cache. Meant to handle model cooldown logic
"""

import functools
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -44,7 +45,7 @@ def _common_add_cooldown_logic(
) -> Tuple[str, CooldownCacheValue]:
try:
current_time = time.time()
cooldown_key = f"deployment:{model_id}:cooldown"
cooldown_key = CooldownCache.get_cooldown_cache_key(model_id)

# Store the cooldown information for the deployment separately
cooldown_data = CooldownCacheValue(
Expand Down Expand Up @@ -104,8 +105,9 @@ def add_deployment_to_cooldown(
raise e

@staticmethod
@functools.lru_cache(maxsize=1024)
def get_cooldown_cache_key(model_id: str) -> str:
return f"deployment:{model_id}:cooldown"
return "deployment:" + model_id + ":cooldown"

async def async_get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
Expand Down Expand Up @@ -140,7 +142,7 @@ def get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
keys = [CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids]
# Retrieve the values for the keys using mget
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
Expand All @@ -162,7 +164,7 @@ def get_min_cooldown(
"""Return min cooldown time required for a group of model id's."""

# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
keys = [CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids]

# Retrieve the values for the keys using mget
results = (
Expand Down
14 changes: 5 additions & 9 deletions litellm/router_utils/prompt_caching_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import hashlib
import json
import orjson
from typing import TYPE_CHECKING, Any, List, Optional, Union

from typing_extensions import TypedDict
Expand Down Expand Up @@ -41,8 +41,8 @@ def serialize_object(obj: Any) -> Any:
return obj.dict()
elif isinstance(obj, dict):
# If the object is a dictionary, serialize it with sorted keys
return json.dumps(
obj, sort_keys=True, separators=(",", ":")
return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode(
"utf-8"
) # Standardize serialization

elif isinstance(obj, list):
Expand All @@ -69,14 +69,10 @@ def get_prompt_caching_cache_key(
data_to_hash["tools"] = serialized_tools

# Combine serialized data into a single string
data_to_hash_str = json.dumps(
data_to_hash,
sort_keys=True,
separators=(",", ":"),
)
data_to_hash_bytes = orjson.dumps(data_to_hash, option=orjson.OPT_SORT_KEYS)

# Create a hash of the serialized data for a stable cache key
hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
hashed_data = hashlib.sha256(data_to_hash_bytes).hexdigest()
return f"deployment:{hashed_data}:prompt_caching"

def add_model_id(
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jinja2 = "^3.1.2"
aiohttp = ">=3.10"
pydantic = "^2.5.0"
jsonschema = "^4.22.0"
orjson = "^3.9.7"
numpydoc = {version = "*", optional = true} # used in utils.py

uvicorn = {version = "^0.29.0", optional = true}
Expand All @@ -41,7 +42,6 @@ fastapi = {version = ">=0.120.1", optional = true}
backoff = {version = "*", optional = true}
pyyaml = {version = "^6.0.1", optional = true}
rq = {version = "*", optional = true}
orjson = {version = "^3.9.7", optional = true}
apscheduler = {version = "^3.10.4", optional = true}
fastapi-sso = { version = "^0.16.0", optional = true }
PyJWT = { version = "^2.8.0", optional = true }
Expand Down Expand Up @@ -76,7 +76,6 @@ proxy = [
"backoff",
"pyyaml",
"rq",
"orjson",
"apscheduler",
"fastapi-sso",
"PyJWT",
Expand Down
Loading
Loading