Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions weavel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
WsLocalGlobalMetricRequest,
WsLocalMetricRequest,
WsLocalTask,
WsServerOptimizeResponse,
WsServerTask,
)

Expand Down Expand Up @@ -946,15 +945,15 @@ def _get_global_metric(self) -> Optional[BaseGlobalMetric]:
def _set_global_metric(self, global_metric: Optional[BaseGlobalMetric]):
self._global_metric_var.set(global_metric)

@websocket_handler(WsLocalTask.GENERATE.value)
@websocket_handler(WsLocalTask.GENERATE)
async def handle_generation_request(self, data: WsLocalGenerateRequest):
logger.debug("Handling generation request...")
generator = self._get_generator()
if not generator:
raise AttributeError("Generate not set")
return await generator(prompt=Prompt(**data["prompt"]), inputs=data["inputs"])

@websocket_handler(WsLocalTask.EVALUATE.value)
@websocket_handler(WsLocalTask.EVALUATE)
async def handle_evaluation_request(
self, data: WsLocalEvaluateRequest
) -> WsLocalEvaluateResponse:
Expand Down Expand Up @@ -988,7 +987,7 @@ async def handle_evaluation_request(
"global_result": global_result.model_dump(),
}

@websocket_handler(WsLocalTask.METRIC.value)
@websocket_handler(WsLocalTask.METRIC)
async def handle_metric_request(self, data: WsLocalMetricRequest):
logger.debug("Handling metric request...")
metric = self._get_metric()
Expand All @@ -997,7 +996,7 @@ async def handle_metric_request(self, data: WsLocalMetricRequest):
res = await metric(dataset_item=data["dataset_item"], pred=data["pred"])
return res.model_dump()

@websocket_handler(WsLocalTask.GLOBAL_METRIC.value)
@websocket_handler(WsLocalTask.GLOBAL_METRIC)
async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest):
logger.debug("Handling global metric request...")
global_metric = self._get_global_metric()
Expand All @@ -1007,7 +1006,7 @@ async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest):
res = await global_metric(results=results)
return res.model_dump()

@websocket_handler(WsServerTask.OPTIMIZE.value)
@websocket_handler(WsServerTask.OPTIMIZE)
async def handle_optimization_result(self, data: Dict[str, Any]):
# Extract the correlation_id from the response data
correlation_id = data.get("correlation_id")
Expand Down Expand Up @@ -1094,10 +1093,11 @@ async def optimize(
raise ValueError(
"base_prompt must be either a Prompt or WvPromptVersion object"
)
prompt_name = wv_prompt.name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will raise an error when base_prompt is an instance of WvPromptVersion


dataset_created = False
if not isinstance(trainset, WvDataset):
dataset_name = f"trainset-{uuid4()}"
dataset_name = f"prompt-{prompt_name}-trainset-{datetime.now().strftime('%Y-%m-%d_%H-%M')}"
dataset = await self.acreate_dataset(name=dataset_name)
dataset_created = True
dataset_items = [
Expand Down
8 changes: 5 additions & 3 deletions weavel/clients/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ def relevant_message_types(self) -> List[str]:
Add all relevant message types that should reset the timeout here.
"""
return [
WsLocalTask.GENERATE.value,
WsLocalTask.EVALUATE.value,
WsLocalTask.GENERATE,
WsLocalTask.EVALUATE,
WsLocalTask.METRIC,
WsLocalTask.GLOBAL_METRIC,
# Add other message types as needed
]

Expand Down Expand Up @@ -370,7 +372,7 @@ async def request(self, type: WsServerTask, data: Dict[str, Any] = {}):

message = {
"correlation_id": correlation_id,
"type": type.value,
"type": type,
"data": data,
}
try:
Expand Down
12 changes: 4 additions & 8 deletions weavel/types/websocket.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
from enum import StrEnum
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union
from typing_extensions import TypedDict
from openai.types.chat.completion_create_params import ChatCompletionMessageParam
from ape.common.types import DatasetItem, MetricResult, GlobalMetricResult
class WsLocalTask(StrEnum):

class WsLocalTask(str, Enum):
GENERATE = "GENERATE"
EVALUATE = "EVALUATE"
METRIC = "METRIC"
GLOBAL_METRIC = "GLOBAL_METRIC"


class WsServerTask(StrEnum):
class WsServerTask(str, Enum):
OPTIMIZE = "OPTIMIZE"


class WsServerOptimizeResponse(StrEnum):
OPTIMIZATION_COMPLETE = "OPTIMIZATION_COMPLETE"


class BaseWsLocalRequest(TypedDict):
type: WsLocalTask
correlation_id: str
Expand Down