Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from langserve.serialization import Serializer

try:
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi import APIRouter, Body, Depends, FastAPI, Request, Response
except ImportError:
# [server] extra not installed
APIRouter = Depends = FastAPI = Request = Response = Any
APIRouter = Body = Depends = FastAPI = Request = Response = Any

# A function that that takes a config and a raw request
# and updates the config based on the request.
Expand Down Expand Up @@ -750,7 +750,7 @@ async def config_schema_with_config(
#######################################
# Documentation variants of end points.
#######################################
# At the moment, we only support pydantic 1.x for documentation
# Prepare models for FastAPI docs (compatible with Pydantic v1 and v2)
InvokeRequest = api_handler.InvokeRequest
InvokeResponse = api_handler.InvokeResponse
BatchRequest = api_handler.BatchRequest
Expand All @@ -759,10 +759,39 @@ async def config_schema_with_config(
StreamLogRequest = api_handler.StreamLogRequest
StreamEventsRequest = api_handler.StreamEventsRequest

# In Pydantic v2, dynamically created models may require model_rebuild()
# to resolve forward refs before being used by FastAPI's TypeAdapter.
def _rebuild_model(m):
try:
rebuild = getattr(m, "model_rebuild", None)
if callable(rebuild):
rebuild(recursive=True)
return
except Exception:
pass
# Fallback for Pydantic v1
try:
upd = getattr(m, "update_forward_refs", None)
if callable(upd):
upd()
except Exception:
pass

for _m in (
InvokeRequest,
InvokeResponse,
BatchRequest,
BatchResponse,
StreamRequest,
StreamLogRequest,
StreamEventsRequest,
):
_rebuild_model(_m)

if endpoint_configuration.is_invoke_enabled:

async def _invoke_docs(
invoke_request: Annotated[InvokeRequest, InvokeRequest],
invoke_request: Annotated[InvokeRequest, Body()],
config_hash: str = "",
) -> InvokeResponse:
"""Invoke the runnable with the given input and config."""
Expand Down Expand Up @@ -795,7 +824,7 @@ async def _invoke_docs(
if endpoint_configuration.is_batch_enabled:

async def _batch_docs(
batch_request: Annotated[BatchRequest, BatchRequest],
batch_request: Annotated[BatchRequest, Body()],
config_hash: str = "",
) -> BatchResponse:
"""Batch invoke the runnable with the given inputs and config."""
Expand Down Expand Up @@ -828,7 +857,7 @@ async def _batch_docs(
if endpoint_configuration.is_stream_enabled:

async def _stream_docs(
stream_request: Annotated[StreamRequest, StreamRequest],
stream_request: Annotated[StreamRequest, Body()],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream the output.
Expand Down Expand Up @@ -912,7 +941,7 @@ async def _stream_docs(
if endpoint_configuration.is_stream_log_enabled:

async def _stream_log_docs(
stream_log_request: Annotated[StreamLogRequest, StreamLogRequest],
stream_log_request: Annotated[StreamLogRequest, Body()],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream_log the output.
Expand Down Expand Up @@ -986,7 +1015,7 @@ async def _stream_log_docs(
if has_astream_events and endpoint_configuration.is_stream_events_enabled:

async def _stream_events_docs(
stream_events_request: Annotated[StreamEventsRequest, StreamEventsRequest],
stream_events_request: Annotated[StreamEventsRequest, Body()],
config_hash: str = "",
) -> EventSourceResponse:
"""Stream events from the given runnable.
Expand Down