Update_testing #15
base: main
```diff
@@ -79,16 +79,19 @@ class LiteLLM(LiteLLMBase):
     def invoke(self, payload, **kwargs):
         try:
             response = completion(model=self.litellm_model, **payload, **kwargs)
-            assert isinstance(response, ModelResponse)
+            if not isinstance(response, ModelResponse):
+                raise ValueError(f"Expected ModelResponse, got {type(response)}")
             response = self._parse_converse_response(response)
             response.input_prompt = self._parse_payload(payload)
             return response

         except Exception as e:
             logger.exception(e)
-            return InvocationResponse.error_output(
-                id=uuid4().hex, error=str(e), input_prompt=self._parse_payload(payload)
+            response = InvocationResponse.error_output(
+                input_payload=payload, error=e, id=uuid4().hex
             )
+            response.input_prompt = self._parse_payload(payload)
+            return response

     def _parse_converse_response(
         self, client_response: ModelResponse
@@ -102,29 +105,52 @@ def _parse_converse_response(
             response.num_tokens_input = usage.prompt_tokens
             response.num_tokens_output = usage.completion_tokens
         except AttributeError:
-            pass
+            response.num_tokens_input = None
+            response.num_tokens_output = None
```
Contributor (comment on lines -105 to +109): Looking at the …
```diff

         return response


 class LiteLLMStreaming(LiteLLMBase):
     def invoke(self, payload, **kwargs):
-        if ("stream" not in kwargs) or ("stream" not in payload):
-            kwargs["stream"] = True
-
-        if ("stream_options" not in kwargs) or ("stream_options" not in payload):
-            kwargs["stream_options"] = {"include_usage": True}
+        # Make a copy of payload to avoid modifying the original
+        payload_copy = payload.copy()
+
+        # Create a clean kwargs dict without conflicting parameters
+        clean_kwargs = {}
+        for key, value in kwargs.items():
+            if key not in ["stream", "stream_options"]:
+                clean_kwargs[key] = value
+
+        # Ensure streaming is enabled
+        payload_copy["stream"] = True
+
+        # Handle stream_options - merge if exists in kwargs, otherwise set default
+        if "stream_options" in kwargs:
+            existing_options = kwargs.get("stream_options", {})
+            payload_copy["stream_options"] = {**existing_options, "include_usage": True}
```
Contributor (comment on lines +130 to +131): Should we merge in case there are some `stream_options` in the payload too? Could be e.g.

```python
existing_kwargs_options = kwargs["stream_options"]
existing_payload_options = payload_copy.get("stream_options", {})
payload_copy["stream_options"] = {
    **existing_payload_options,
    **existing_kwargs_options,
    "include_usage": True,
}
```
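The merge order in that suggestion matters: in a Python dict literal, later unpackings override earlier keys, so kwargs-supplied options win over payload options and `include_usage` is always forced on last. A minimal sketch of that precedence, with made-up option values only for illustration:

```python
# Hypothetical option dicts, purely to show merge precedence.
payload_options = {"include_usage": False, "chunk_size": 1}
kwargs_options = {"chunk_size": 8}

merged = {
    **payload_options,       # lowest precedence
    **kwargs_options,        # overrides payload keys ("chunk_size" becomes 8)
    "include_usage": True,   # forced on, regardless of the inputs
}

assert merged == {"include_usage": True, "chunk_size": 8}
print(merged)
```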
```diff
+        elif "stream_options" not in payload_copy:
+            payload_copy["stream_options"] = {"include_usage": True}
+        else:
+            # Merge with existing stream_options in payload if present
+            existing_options = payload_copy.get("stream_options", {})
+            payload_copy["stream_options"] = {**existing_options, "include_usage": True}

         try:
             start_t = time.perf_counter()
-            response = completion(model=self.litellm_model, **payload, **kwargs)
+            response = completion(
+                model=self.litellm_model, **payload_copy, **clean_kwargs
+            )
         except Exception as e:
             logger.exception(e)
-            return InvocationResponse.error_output(
-                id=uuid4().hex, error=str(e), input_prompt=self._parse_payload(payload)
+            response = InvocationResponse.error_output(
+                input_payload=payload, error=e, id=uuid4().hex
             )
+            response.input_prompt = self._parse_payload(payload)
+            return response

-        assert isinstance(response, CustomStreamWrapper)
+        if not isinstance(response, CustomStreamWrapper):
+            raise ValueError(f"Expected CustomStreamWrapper, got {type(response)}")
         response = self._parse_stream(response, start_t)
         response.input_prompt = self._parse_payload(payload)
         return response
```
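A practical difference behind replacing `assert isinstance(...)` with an explicit check, here and in `LiteLLM.invoke` above: assertions are stripped entirely when Python runs with `-O`, while an explicit `raise ValueError(...)` still fires. A small standalone sketch of the pattern, using a stand-in class rather than the real `CustomStreamWrapper`:

```python
class FakeStreamWrapper:  # stand-in for CustomStreamWrapper
    pass


def check(response):
    # Unlike `assert isinstance(...)`, this check survives `python -O`.
    if not isinstance(response, FakeStreamWrapper):
        raise ValueError(f"Expected FakeStreamWrapper, got {type(response)}")
    return response


check(FakeStreamWrapper())   # passes silently
try:
    check("not a stream")    # raises ValueError
except ValueError as err:
    print(err)
```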
```diff
@@ -136,12 +162,21 @@ def _parse_stream(
         time_flag = True
         time_to_first_token = None
         output_text = ""
+        id = None

         for chunk in client_response:
-            output_text += chunk.choices[0].delta.content or ""  # type: ignore
-            if time_flag:
+            content = chunk.choices[0].delta.content or ""  # type: ignore
+            output_text += content
+
+            # Record time to first token only when we get actual content
+            if time_flag and content:
                 time_to_first_token = time.perf_counter() - start_t
                 time_flag = False
+
+            # Always capture the ID from the first chunk
+            if id is None:
+                id = chunk.id

             try:
                 usage = chunk.usage  # type: ignore
             except AttributeError:
```
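The `if time_flag and content:` condition above records time to first token only once a chunk actually carries text, so an initial empty delta (e.g. a role-only chunk) no longer counts as the first token. A rough standalone sketch of that behaviour, using a list of fake chunk contents instead of a real LiteLLM stream:

```python
import time


def time_to_first_token(chunk_contents, start_t):
    """Return (seconds until first non-empty chunk, accumulated text)."""
    time_flag = True
    ttft = None
    output_text = ""
    for content in chunk_contents:
        content = content or ""
        output_text += content
        # Only record once we see actual content, mirroring the diff above.
        if time_flag and content:
            ttft = time.perf_counter() - start_t
            time_flag = False
    return ttft, output_text


start = time.perf_counter()
# The first "chunk" is empty (like a role-only delta), so it is skipped for TTFT.
print(time_to_first_token(["", "Hello", " world"], start))
```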
```diff
@@ -24,7 +24,12 @@ class DeferredError:
     """

     def __init__(self, exception):
-        self.exc = exception
+        # Ensure the exception is a BaseException instance
+        if isinstance(exception, BaseException):
+            self.exc = exception
+        else:
+            # If it's not a BaseException, wrap it in an ImportError
+            self.exc = ImportError(str(exception))
```
Contributor (comment on lines +30 to +32):
Although we use it only for …
My asks would probably be to: …
```diff

     def __getattr__(self, name):
         """Called by Python interpreter before using any method or property on the object.
```
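For context on the class being changed above: `DeferredError` objects like this are commonly used to postpone an optional-import failure until the stand-in object is actually touched. A minimal sketch of that pattern, assuming `__getattr__` re-raises the stored exception on first attribute access and using a made-up module name:

```python
class DeferredError:
    """Minimal stand-in mirroring the diff above; not the project's actual class."""

    def __init__(self, exception):
        # Mirror the change above: wrap non-exceptions in an ImportError.
        if isinstance(exception, BaseException):
            self.exc = exception
        else:
            self.exc = ImportError(str(exception))

    def __getattr__(self, name):
        # Surface the original error only when the object is actually used.
        raise self.exc


try:
    import some_optional_dep  # hypothetical optional dependency
except ImportError as e:
    some_optional_dep = DeferredError(e)

# Import time succeeds either way; the error only appears on first use:
# some_optional_dep.do_something()  -> raises the deferred ImportError
```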
```diff
@@ -0,0 +1,2 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
```
Contributor: For me, …
Looks like `runner.py`, `endpoints/bedrock.py`, `endpoints/openai.py`, and `endpoints/sagemaker.py` all still have some cases using `error=str(e)` - do we care enough to fix that consistently and add an `Exception | str | None` type annotation to `InvocationResponse.error_output()`'s definition?
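If that clean-up happens, the annotation could look roughly like the sketch below. `InvocationResponse`'s real definition is not part of this diff, so the dataclass here is only a hypothetical stand-in showing how `Exception | str | None` could be normalized in one place:

```python
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class InvocationResponse:  # hypothetical stand-in, not the project's real class
    id: str
    error: str | None = None
    input_payload: dict | None = None

    @classmethod
    def error_output(
        cls,
        id: str,
        error: Exception | str | None = None,
        input_payload: dict | None = None,
    ) -> "InvocationResponse":
        # Accept either an Exception or a plain string and normalize to text,
        # so call sites no longer need to write error=str(e) themselves.
        error_text = None if error is None else str(error)
        return cls(id=id, error=error_text, input_payload=input_payload)


resp = InvocationResponse.error_output(id=uuid4().hex, error=ValueError("boom"))
print(resp.error)  # "boom"
```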