diff --git a/camel/societies/workforce/events.py b/camel/societies/workforce/events.py index 2800739319..45786854a3 100644 --- a/camel/societies/workforce/events.py +++ b/camel/societies/workforce/events.py @@ -32,6 +32,7 @@ class WorkforceEventBase(BaseModel): "worker_deleted", "queue_status", "all_tasks_completed", + "task_streaming_chunk", ] metadata: Optional[Dict[str, Any]] = None timestamp: datetime = Field( @@ -96,6 +97,14 @@ class TaskFailedEvent(WorkforceEventBase): worker_id: Optional[str] = None +class TaskStreamingChunkEvent(WorkforceEventBase): + event_type: Literal["task_streaming_chunk"] = "task_streaming_chunk" + task_id: str + worker_id: str + chunk: str + chunk_index: int + + class AllTasksCompletedEvent(WorkforceEventBase): event_type: Literal["all_tasks_completed"] = "all_tasks_completed" @@ -109,6 +118,7 @@ class QueueStatusEvent(WorkforceEventBase): WorkforceEvent = Union[ + TaskStreamingChunkEvent, TaskDecomposedEvent, TaskCreatedEvent, TaskAssignedEvent, diff --git a/camel/societies/workforce/single_agent_worker.py b/camel/societies/workforce/single_agent_worker.py index 752ecbb18c..3d02336562 100644 --- a/camel/societies/workforce/single_agent_worker.py +++ b/camel/societies/workforce/single_agent_worker.py @@ -24,6 +24,7 @@ from camel.agents import ChatAgent from camel.agents.chat_agent import AsyncStreamingChatAgentResponse from camel.logger import get_logger +from camel.societies.workforce.events import TaskStreamingChunkEvent from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT from camel.societies.workforce.structured_output_handler import ( StructuredOutputHandler, @@ -393,9 +394,32 @@ async def _process_task( # Handle streaming response if isinstance(response, AsyncStreamingChatAgentResponse): content = "" + chunk_index = 0 async for chunk in response: - if chunk.msg: - content = chunk.msg.content + if chunk.msg and chunk.msg.content: + chunk_event = TaskStreamingChunkEvent( + task_id=task.id, + worker_id=self.node_id, + chunk=chunk.msg.content, + chunk_index=chunk_index, + ) + if ( + hasattr(self, 'callback') + and self.callback is not None + ): + try: + await ( + self.callback.log_task_streaming_chunk( + chunk_event + ) + ) + except Exception as e: + logger.warning( + f"Failed to log streaming chunk: {e}" + ) + + content += chunk.msg.content + chunk_index += 1 response_content = content else: # Regular ChatAgentResponse @@ -422,10 +446,34 @@ async def _process_task( # Handle streaming response for native output if isinstance(response, AsyncStreamingChatAgentResponse): task_result = None + chunk_index = 0 async for chunk in response: - if chunk.msg and chunk.msg.parsed: - task_result = chunk.msg.parsed - response_content = chunk.msg.content + if chunk.msg: + if chunk.msg.content: + chunk_event = TaskStreamingChunkEvent( + task_id=task.id, + worker_id=self.node_id, + chunk=chunk.msg.content, + chunk_index=chunk_index, + ) + if ( + hasattr(self, 'callback') + and self.callback is not None + ): + try: + await self.callback.log_task_streaming_chunk( # noqa: E501 + chunk_event + ) + except Exception as e: + logger.warning( + f"Failed to log chunk: {e}" + ) + chunk_index += 1 + + if chunk.msg.parsed: + task_result = chunk.msg.parsed + response_content = chunk.msg.content + # If no parsed result found in streaming, create fallback if task_result is None: task_result = TaskResult( diff --git a/camel/societies/workforce/workforce_callback.py b/camel/societies/workforce/workforce_callback.py index 9dee1d559a..faf144cde1 100644 --- a/camel/societies/workforce/workforce_callback.py +++ b/camel/societies/workforce/workforce_callback.py @@ -23,6 +23,7 @@ TaskDecomposedEvent, TaskFailedEvent, TaskStartedEvent, + TaskStreamingChunkEvent, WorkerCreatedEvent, WorkerDeletedEvent, ) @@ -72,3 +73,7 @@ def log_worker_deleted(self, event: WorkerDeletedEvent) -> None: @abstractmethod def log_all_tasks_completed(self, event: AllTasksCompletedEvent) -> None: pass + + @abstractmethod + def log_task_streaming_chunk(self, event: TaskStreamingChunkEvent) -> None: + pass diff --git a/camel/societies/workforce/workforce_logger.py b/camel/societies/workforce/workforce_logger.py index b04c134288..9e8d9936b3 100644 --- a/camel/societies/workforce/workforce_logger.py +++ b/camel/societies/workforce/workforce_logger.py @@ -25,6 +25,7 @@ TaskDecomposedEvent, TaskFailedEvent, TaskStartedEvent, + TaskStreamingChunkEvent, WorkerCreatedEvent, WorkerDeletedEvent, ) @@ -49,6 +50,7 @@ def __init__(self, workforce_id: str): self._task_hierarchy: Dict[str, Dict[str, Any]] = {} self._worker_information: Dict[str, Dict[str, Any]] = {} self._initial_worker_logs: List[Dict[str, Any]] = [] + self._streaming_chunks: Dict[str, List[Dict[str, Any]]] = {} def _log_event(self, event_type: str, **kwargs: Any) -> None: r"""Internal method to create and store a log entry. @@ -67,6 +69,31 @@ def _log_event(self, event_type: str, **kwargs: Any) -> None: if event_type == 'worker_created': self._initial_worker_logs.append(log_entry) + def log_task_streaming_chunk(self, event: TaskStreamingChunkEvent) -> None: + r"""Logs a streaming chunk from task execution. + + Args: + event (TaskStreamingChunkEvent): The streaming chunk event. + """ + if event.task_id not in self._streaming_chunks: + self._streaming_chunks[event.task_id] = [] + + chunk_data = { + 'chunk_index': event.chunk_index, + 'chunk': event.chunk, + 'worker_id': event.worker_id, + } + self._streaming_chunks[event.task_id].append(chunk_data) + + self._log_event( + event_type=event.event_type, + task_id=event.task_id, + worker_id=event.worker_id, + chunk=event.chunk, + chunk_index=event.chunk_index, + metadata=event.metadata or {}, + ) + def log_task_created( self, event: TaskCreatedEvent,