diff --git a/langserve/server.py b/langserve/server.py index 4602a1d3..488a081a 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -29,10 +29,10 @@ from langserve.serialization import Serializer try: - from fastapi import APIRouter, Depends, FastAPI, Request, Response + from fastapi import APIRouter, Body, Depends, FastAPI, Request, Response except ImportError: # [server] extra not installed - APIRouter = Depends = FastAPI = Request = Response = Any + APIRouter = Body = Depends = FastAPI = Request = Response = Any # A function that that takes a config and a raw request # and updates the config based on the request. @@ -750,7 +750,7 @@ async def config_schema_with_config( ####################################### # Documentation variants of end points. ####################################### - # At the moment, we only support pydantic 1.x for documentation + # Prepare models for FastAPI docs (compatible with Pydantic v1 and v2) InvokeRequest = api_handler.InvokeRequest InvokeResponse = api_handler.InvokeResponse BatchRequest = api_handler.BatchRequest @@ -759,10 +759,39 @@ async def config_schema_with_config( StreamLogRequest = api_handler.StreamLogRequest StreamEventsRequest = api_handler.StreamEventsRequest + # In Pydantic v2, dynamically created models may require model_rebuild() + # to resolve forward refs before being used by FastAPI's TypeAdapter. + def _rebuild_model(m): + try: + rebuild = getattr(m, "model_rebuild", None) + if callable(rebuild): + rebuild(recursive=True) + return + except Exception: + pass + # Fallback for Pydantic v1 + try: + upd = getattr(m, "update_forward_refs", None) + if callable(upd): + upd() + except Exception: + pass + + for _m in ( + InvokeRequest, + InvokeResponse, + BatchRequest, + BatchResponse, + StreamRequest, + StreamLogRequest, + StreamEventsRequest, + ): + _rebuild_model(_m) + if endpoint_configuration.is_invoke_enabled: async def _invoke_docs( - invoke_request: Annotated[InvokeRequest, InvokeRequest], + invoke_request: Annotated[InvokeRequest, Body()], config_hash: str = "", ) -> InvokeResponse: """Invoke the runnable with the given input and config.""" @@ -795,7 +824,7 @@ async def _invoke_docs( if endpoint_configuration.is_batch_enabled: async def _batch_docs( - batch_request: Annotated[BatchRequest, BatchRequest], + batch_request: Annotated[BatchRequest, Body()], config_hash: str = "", ) -> BatchResponse: """Batch invoke the runnable with the given inputs and config.""" @@ -828,7 +857,7 @@ async def _batch_docs( if endpoint_configuration.is_stream_enabled: async def _stream_docs( - stream_request: Annotated[StreamRequest, StreamRequest], + stream_request: Annotated[StreamRequest, Body()], config_hash: str = "", ) -> EventSourceResponse: """Invoke the runnable stream the output. @@ -912,7 +941,7 @@ async def _stream_docs( if endpoint_configuration.is_stream_log_enabled: async def _stream_log_docs( - stream_log_request: Annotated[StreamLogRequest, StreamLogRequest], + stream_log_request: Annotated[StreamLogRequest, Body()], config_hash: str = "", ) -> EventSourceResponse: """Invoke the runnable stream_log the output. @@ -986,7 +1015,7 @@ async def _stream_log_docs( if has_astream_events and endpoint_configuration.is_stream_events_enabled: async def _stream_events_docs( - stream_events_request: Annotated[StreamEventsRequest, StreamEventsRequest], + stream_events_request: Annotated[StreamEventsRequest, Body()], config_hash: str = "", ) -> EventSourceResponse: """Stream events from the given runnable.