From ab87dad5e8f6c296d32c099624f8f731193e5563 Mon Sep 17 00:00:00 2001 From: engineerA314 Date: Sat, 5 Oct 2024 13:05:16 -0700 Subject: [PATCH 1/3] fix: fix StrEnum -> Enum --- weavel/client.py | 3 +-- weavel/clients/websocket_client.py | 4 +++- weavel/types/websocket.py | 12 ++++-------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/weavel/client.py b/weavel/client.py index af893ee..836e098 100644 --- a/weavel/client.py +++ b/weavel/client.py @@ -49,7 +49,6 @@ WsLocalGlobalMetricRequest, WsLocalMetricRequest, WsLocalTask, - WsServerOptimizeResponse, WsServerTask, ) @@ -1131,7 +1130,7 @@ async def optimize( ): async with self.ws_client: res = await self.ws_client.request( - type=WsServerTask.OPTIMIZE, + type=WsServerTask.OPTIMIZE.value, data={ "base_prompt_version_uuid": wv_prompt_version.uuid, "models": models, diff --git a/weavel/clients/websocket_client.py b/weavel/clients/websocket_client.py index 46fa856..c7babb4 100644 --- a/weavel/clients/websocket_client.py +++ b/weavel/clients/websocket_client.py @@ -212,6 +212,8 @@ def relevant_message_types(self) -> List[str]: return [ WsLocalTask.GENERATE.value, WsLocalTask.EVALUATE.value, + WsLocalTask.METRIC.value, + WsLocalTask.GLOBAL_METRIC.value, # Add other message types as needed ] @@ -370,7 +372,7 @@ async def request(self, type: WsServerTask, data: Dict[str, Any] = {}): message = { "correlation_id": correlation_id, - "type": type.value, + "type": type.value if isinstance(type, WsServerTask) else type, "data": data, } try: diff --git a/weavel/types/websocket.py b/weavel/types/websocket.py index 1884f9e..3322688 100644 --- a/weavel/types/websocket.py +++ b/weavel/types/websocket.py @@ -1,23 +1,19 @@ -from enum import StrEnum +from enum import Enum from typing import Any, Dict, Iterable, List, Optional, Union from typing_extensions import TypedDict from openai.types.chat.completion_create_params import ChatCompletionMessageParam from ape.common.types import DatasetItem, MetricResult, GlobalMetricResult -class WsLocalTask(StrEnum): + +class WsLocalTask(Enum): GENERATE = "GENERATE" EVALUATE = "EVALUATE" METRIC = "METRIC" GLOBAL_METRIC = "GLOBAL_METRIC" -class WsServerTask(StrEnum): +class WsServerTask(Enum): OPTIMIZE = "OPTIMIZE" - -class WsServerOptimizeResponse(StrEnum): - OPTIMIZATION_COMPLETE = "OPTIMIZATION_COMPLETE" - - class BaseWsLocalRequest(TypedDict): type: WsLocalTask correlation_id: str From ae91ec05b16700855c5932e574899167020fd0bb Mon Sep 17 00:00:00 2001 From: engineerA314 Date: Tue, 8 Oct 2024 13:59:19 -0700 Subject: [PATCH 2/3] chore: use str, Enum instead of Enum --- weavel/client.py | 12 ++++++------ weavel/clients/websocket_client.py | 10 +++++----- weavel/types/websocket.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/weavel/client.py b/weavel/client.py index 836e098..eff20ce 100644 --- a/weavel/client.py +++ b/weavel/client.py @@ -945,7 +945,7 @@ def _get_global_metric(self) -> Optional[BaseGlobalMetric]: def _set_global_metric(self, global_metric: Optional[BaseGlobalMetric]): self._global_metric_var.set(global_metric) - @websocket_handler(WsLocalTask.GENERATE.value) + @websocket_handler(WsLocalTask.GENERATE) async def handle_generation_request(self, data: WsLocalGenerateRequest): logger.debug("Handling generation request...") generator = self._get_generator() @@ -953,7 +953,7 @@ async def handle_generation_request(self, data: WsLocalGenerateRequest): raise AttributeError("Generate not set") return await generator(prompt=Prompt(**data["prompt"]), inputs=data["inputs"]) - @websocket_handler(WsLocalTask.EVALUATE.value) + @websocket_handler(WsLocalTask.EVALUATE) async def handle_evaluation_request( self, data: WsLocalEvaluateRequest ) -> WsLocalEvaluateResponse: @@ -987,7 +987,7 @@ async def handle_evaluation_request( "global_result": global_result.model_dump(), } - @websocket_handler(WsLocalTask.METRIC.value) + @websocket_handler(WsLocalTask.METRIC) async def handle_metric_request(self, data: WsLocalMetricRequest): logger.debug("Handling metric request...") metric = self._get_metric() @@ -996,7 +996,7 @@ async def handle_metric_request(self, data: WsLocalMetricRequest): res = await metric(dataset_item=data["dataset_item"], pred=data["pred"]) return res.model_dump() - @websocket_handler(WsLocalTask.GLOBAL_METRIC.value) + @websocket_handler(WsLocalTask.GLOBAL_METRIC) async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest): logger.debug("Handling global metric request...") global_metric = self._get_global_metric() @@ -1006,7 +1006,7 @@ async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest): res = await global_metric(results=results) return res.model_dump() - @websocket_handler(WsServerTask.OPTIMIZE.value) + @websocket_handler(WsServerTask.OPTIMIZE) async def handle_optimization_result(self, data: Dict[str, Any]): # Extract the correlation_id from the response data correlation_id = data.get("correlation_id") @@ -1130,7 +1130,7 @@ async def optimize( ): async with self.ws_client: res = await self.ws_client.request( - type=WsServerTask.OPTIMIZE.value, + type=WsServerTask.OPTIMIZE, data={ "base_prompt_version_uuid": wv_prompt_version.uuid, "models": models, diff --git a/weavel/clients/websocket_client.py b/weavel/clients/websocket_client.py index c7babb4..50ae6d3 100644 --- a/weavel/clients/websocket_client.py +++ b/weavel/clients/websocket_client.py @@ -210,10 +210,10 @@ def relevant_message_types(self) -> List[str]: Add all relevant message types that should reset the timeout here. """ return [ - WsLocalTask.GENERATE.value, - WsLocalTask.EVALUATE.value, - WsLocalTask.METRIC.value, - WsLocalTask.GLOBAL_METRIC.value, + WsLocalTask.GENERATE, + WsLocalTask.EVALUATE, + WsLocalTask.METRIC, + WsLocalTask.GLOBAL_METRIC, # Add other message types as needed ] @@ -372,7 +372,7 @@ async def request(self, type: WsServerTask, data: Dict[str, Any] = {}): message = { "correlation_id": correlation_id, - "type": type.value if isinstance(type, WsServerTask) else type, + "type": type, "data": data, } try: diff --git a/weavel/types/websocket.py b/weavel/types/websocket.py index 3322688..acd3a82 100644 --- a/weavel/types/websocket.py +++ b/weavel/types/websocket.py @@ -4,14 +4,14 @@ from openai.types.chat.completion_create_params import ChatCompletionMessageParam from ape.common.types import DatasetItem, MetricResult, GlobalMetricResult -class WsLocalTask(Enum): +class WsLocalTask(str, Enum): GENERATE = "GENERATE" EVALUATE = "EVALUATE" METRIC = "METRIC" GLOBAL_METRIC = "GLOBAL_METRIC" -class WsServerTask(Enum): +class WsServerTask(str, Enum): OPTIMIZE = "OPTIMIZE" class BaseWsLocalRequest(TypedDict): From 3df3515e0af84a30fdbbfe85a7f526a73388455a Mon Sep 17 00:00:00 2001 From: engineerA314 Date: Sun, 6 Oct 2024 15:06:18 -0700 Subject: [PATCH 3/3] chore: rename dataset name --- weavel/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/weavel/client.py b/weavel/client.py index eff20ce..8fc987a 100644 --- a/weavel/client.py +++ b/weavel/client.py @@ -1093,10 +1093,11 @@ async def optimize( raise ValueError( "base_prompt must be either a Prompt or WvPromptVersion object" ) + prompt_name = wv_prompt.name dataset_created = False if not isinstance(trainset, WvDataset): - dataset_name = f"trainset-{uuid4()}" + dataset_name = f"prompt-{prompt_name}-trainset-{datetime.now().strftime('%Y-%m-%d_%H-%M')}" dataset = await self.acreate_dataset(name=dataset_name) dataset_created = True dataset_items = [