diff --git a/weavel/client.py b/weavel/client.py index af893ee..8fc987a 100644 --- a/weavel/client.py +++ b/weavel/client.py @@ -49,7 +49,6 @@ WsLocalGlobalMetricRequest, WsLocalMetricRequest, WsLocalTask, - WsServerOptimizeResponse, WsServerTask, ) @@ -946,7 +945,7 @@ def _get_global_metric(self) -> Optional[BaseGlobalMetric]: def _set_global_metric(self, global_metric: Optional[BaseGlobalMetric]): self._global_metric_var.set(global_metric) - @websocket_handler(WsLocalTask.GENERATE.value) + @websocket_handler(WsLocalTask.GENERATE) async def handle_generation_request(self, data: WsLocalGenerateRequest): logger.debug("Handling generation request...") generator = self._get_generator() @@ -954,7 +953,7 @@ async def handle_generation_request(self, data: WsLocalGenerateRequest): raise AttributeError("Generate not set") return await generator(prompt=Prompt(**data["prompt"]), inputs=data["inputs"]) - @websocket_handler(WsLocalTask.EVALUATE.value) + @websocket_handler(WsLocalTask.EVALUATE) async def handle_evaluation_request( self, data: WsLocalEvaluateRequest ) -> WsLocalEvaluateResponse: @@ -988,7 +987,7 @@ async def handle_evaluation_request( "global_result": global_result.model_dump(), } - @websocket_handler(WsLocalTask.METRIC.value) + @websocket_handler(WsLocalTask.METRIC) async def handle_metric_request(self, data: WsLocalMetricRequest): logger.debug("Handling metric request...") metric = self._get_metric() @@ -997,7 +996,7 @@ async def handle_metric_request(self, data: WsLocalMetricRequest): res = await metric(dataset_item=data["dataset_item"], pred=data["pred"]) return res.model_dump() - @websocket_handler(WsLocalTask.GLOBAL_METRIC.value) + @websocket_handler(WsLocalTask.GLOBAL_METRIC) async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest): logger.debug("Handling global metric request...") global_metric = self._get_global_metric() @@ -1007,7 +1006,7 @@ async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest): res = await global_metric(results=results) return res.model_dump() - @websocket_handler(WsServerTask.OPTIMIZE.value) + @websocket_handler(WsServerTask.OPTIMIZE) async def handle_optimization_result(self, data: Dict[str, Any]): # Extract the correlation_id from the response data correlation_id = data.get("correlation_id") @@ -1094,10 +1093,11 @@ async def optimize( raise ValueError( "base_prompt must be either a Prompt or WvPromptVersion object" ) + prompt_name = wv_prompt.name dataset_created = False if not isinstance(trainset, WvDataset): - dataset_name = f"trainset-{uuid4()}" + dataset_name = f"prompt-{prompt_name}-trainset-{datetime.now().strftime('%Y-%m-%d_%H-%M')}" dataset = await self.acreate_dataset(name=dataset_name) dataset_created = True dataset_items = [ diff --git a/weavel/clients/websocket_client.py b/weavel/clients/websocket_client.py index 46fa856..50ae6d3 100644 --- a/weavel/clients/websocket_client.py +++ b/weavel/clients/websocket_client.py @@ -210,8 +210,10 @@ def relevant_message_types(self) -> List[str]: Add all relevant message types that should reset the timeout here. """ return [ - WsLocalTask.GENERATE.value, - WsLocalTask.EVALUATE.value, + WsLocalTask.GENERATE, + WsLocalTask.EVALUATE, + WsLocalTask.METRIC, + WsLocalTask.GLOBAL_METRIC, # Add other message types as needed ] @@ -370,7 +372,7 @@ async def request(self, type: WsServerTask, data: Dict[str, Any] = {}): message = { "correlation_id": correlation_id, - "type": type.value, + "type": type, "data": data, } try: diff --git a/weavel/types/websocket.py b/weavel/types/websocket.py index 1884f9e..acd3a82 100644 --- a/weavel/types/websocket.py +++ b/weavel/types/websocket.py @@ -1,23 +1,19 @@ -from enum import StrEnum +from enum import Enum from typing import Any, Dict, Iterable, List, Optional, Union from typing_extensions import TypedDict from openai.types.chat.completion_create_params import ChatCompletionMessageParam from ape.common.types import DatasetItem, MetricResult, GlobalMetricResult -class WsLocalTask(StrEnum): + +class WsLocalTask(str, Enum): GENERATE = "GENERATE" EVALUATE = "EVALUATE" METRIC = "METRIC" GLOBAL_METRIC = "GLOBAL_METRIC" -class WsServerTask(StrEnum): +class WsServerTask(str, Enum): OPTIMIZE = "OPTIMIZE" - -class WsServerOptimizeResponse(StrEnum): - OPTIMIZATION_COMPLETE = "OPTIMIZATION_COMPLETE" - - class BaseWsLocalRequest(TypedDict): type: WsLocalTask correlation_id: str