diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py index 2562bb1..034a662 100644 --- a/any_llm_client/clients/openai.py +++ b/any_llm_client/clients/openai.py @@ -166,6 +166,12 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn raise LLMError(response_content=content) +def _handle_validation_error(*, content: bytes, original_error: pydantic.ValidationError) -> typing.NoReturn: + if b"is too long to fit into the model" in content: # vLLM + raise OutOfTokensOrSymbolsError(response_content=content) + raise LLMResponseValidationError(response_content=content, original_error=original_error) + + @dataclasses.dataclass(slots=True, init=False) class OpenAIClient(LLMClient): config: OpenAIConfig @@ -243,9 +249,7 @@ async def request_llm_message( ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message ) except pydantic.ValidationError as validation_error: - raise LLMResponseValidationError( - response_content=response.content, original_error=validation_error - ) from validation_error + _handle_validation_error(content=response.content, original_error=validation_error) finally: await response.aclose() @@ -262,9 +266,7 @@ async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncI try: validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data) except pydantic.ValidationError as validation_error: - raise LLMResponseValidationError( - response_content=event.data.encode(), original_error=validation_error - ) from validation_error + _handle_validation_error(content=event.data.encode(), original_error=validation_error) if not ( (validated_delta := validated_response.choices[0].delta) diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py index 566f4ed..ef79c8e 100644 --- a/tests/test_openai_client.py +++ b/tests/test_openai_client.py @@ -149,7 +149,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> b'{"object":"error","message":"This model\'s maximum context length is 16384 tokens. However, you requested 100000 tokens in the messages, Please reduce the length of the messages.","type":"BadRequestError","param":null,"code":400}', # noqa: E501 ], ) - async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None: + async def test_fails_with_out_of_tokens_error_on_status(self, stream: bool, content: bytes) -> None: response: typing.Final = httpx.Response(400, content=content) client: typing.Final = any_llm_client.get_client( OpenAIConfigFactory.build(), @@ -165,6 +165,34 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes with pytest.raises(any_llm_client.OutOfTokensOrSymbolsError): await coroutine + @pytest.mark.parametrize("stream", [True, False]) + @pytest.mark.parametrize( + "content", + [ + b'{"error": {"object": "error", "message": "The prompt (total length 6287) is too long to fit into the model (context length 4096). Make sure that `max_model_len` is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.", "type": "BadRequestError", "param": null, "code": 400}}\n', # noqa: E501 + b'{"object": "error", "message": "The prompt (total length 43431) is too long to fit into the model (context length 8192). Make sure that `max_model_len` is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.", "type": "BadRequestError", "param": null, "code": 400}\n', # noqa: E501 + ], + ) + async def test_fails_with_out_of_tokens_error_on_validation(self, stream: bool, content: bytes) -> None: + response: typing.Final = httpx.Response( + 200, + content=f"data: {content.decode()}\n\n" if stream else content, + headers={"Content-Type": "text/event-stream"} if stream else None, + ) + client: typing.Final = any_llm_client.get_client( + OpenAIConfigFactory.build(), + transport=httpx.MockTransport(lambda _: response), + ) + + coroutine: typing.Final = ( + consume_llm_message_chunks(client.stream_llm_message_chunks(**LLMFuncRequestFactory.build())) + if stream + else client.request_llm_message(**LLMFuncRequestFactory.build()) + ) + + with pytest.raises(any_llm_client.OutOfTokensOrSymbolsError): + await coroutine + class TestOpenAIMessageAlternation: @pytest.mark.parametrize(