diff --git a/AsyncFunction_Design.md b/AsyncFunction_Design.md
new file mode 100644
index 00000000..42f39096
--- /dev/null
+++ b/AsyncFunction_Design.md
@@ -0,0 +1,740 @@
+# AsyncFunction Design and Architecture
+
+## Overview
+
+This document outlines the design and implementation plan for AsyncFunction support in the Microsoft Teams Python SDK. AsyncFunctions enable functions to suspend execution, wait for external input (user responses, webhooks, timers), and resume seamlessly while maintaining conversation state.
+
+## Core Concept
+
+AsyncFunctions can **suspend** during execution and be **resumed** later with additional input. When a function suspends, it returns a `DeferredResult` containing:
+- The suspension state (for later resume)
+- A handler specifying what type of interaction is needed
+
+While a function is suspended, the current conversation pauses until external input is provided.
+
+## Architecture Overview
+
+```mermaid
+graph TD
+    A[User Input] --> B["ChatPrompt.send()"]
+    B --> C{Suspended Functions?}
+    C -->|Yes| D[Resume AsyncFunction]
+    C -->|No| E[Normal Processing]
+
+    D --> F["AsyncFunction.resume()"]
+    F --> G{Still Suspended?}
+    G -->|Yes| H[Return DeferredResult]
+    G -->|No| I[Create FunctionMessage]
+
+    E --> J["AIModel.generate_text()"]
+    I --> J
+    J --> K{Function Calls?}
+    K -->|Yes| L[Execute Functions]
+    K -->|No| M[Return ModelMessage]
+
+    L --> N{DeferredResult?}
+    N -->|Yes| O[Store Suspension State]
+    N -->|No| P[Continue Recursively]
+
+    O --> H
+    P --> K
+    H --> Q[Return to User]
+    M --> Q
+```
+
+## Core Types and Interfaces
+
+### 1. Handler Types
+
+```python
+from dataclasses import dataclass
+from typing import Literal, Union
+
+# Dataclass fields without defaults must precede the defaulted `type`
+# discriminator, so `type` is declared last in each handler.
+@dataclass
+class AskUserHandler:
+    question: str
+    type: Literal["ask_user"] = "ask_user"
+
+@dataclass
+class GetApprovalHandler:
+    prompt: str
+    type: Literal["get_approval"] = "get_approval"
+
+@dataclass
+class SelectFromOptionsHandler:
+    question: str
+    options: list[str]
+    type: Literal["select_options"] = "select_options"
+
+@dataclass
+class WebhookHandler:
+    webhook_url: str
+    type: Literal["webhook"] = "webhook"
+
+@dataclass
+class TimerHandler:
+    delay_seconds: int
+    message: str
+    type: Literal["timer"] = "timer"
+
+# Union of all handler types
+DeferredHandler = Union[
+    AskUserHandler,
+    GetApprovalHandler,
+    SelectFromOptionsHandler,
+    WebhookHandler,
+    TimerHandler
+]
+```
+
+### 2. Generic DeferredResult
+
+```python
+from typing import Any, Generic, TypeVar
+
+THandler = TypeVar("THandler")
+# TResumer exists only at the type level: it documents which input type
+# resumes the result (see the aliases below).
+TResumer = TypeVar("TResumer")
+
+@dataclass
+class DeferredResult(Generic[THandler, TResumer]):
+    state: dict[str, Any]
+    handler: THandler
+    type: Literal["deferred"] = "deferred"
+
+# Type aliases for common patterns
+AskUserResult = DeferredResult[AskUserHandler, UserMessage]
+GetApprovalResult = DeferredResult[GetApprovalHandler, UserMessage]
+SelectFromOptionsResult = DeferredResult[SelectFromOptionsHandler, UserMessage]
+WebhookResult = DeferredResult[WebhookHandler, dict[str, Any]]
+TimerResult = DeferredResult[TimerHandler, None]
+
+AnyDeferredResult = Union[
+    AskUserResult,
+    GetApprovalResult,
+    SelectFromOptionsResult,
+    WebhookResult,
+    TimerResult
+]
+```
+
+### 3. AsyncFunction Protocol
+
+```python
+from typing import Dict, Protocol
+
+from pydantic import BaseModel
+
+Params = TypeVar("Params", bound=BaseModel, contravariant=True)
+
+class AsyncFunctionHandler(Protocol[Params]):
+    def __call__(self, params: Params) -> Union[str, AnyDeferredResult]: ...
+    def resume(self, resumer: Any, state: dict[str, Any]) -> Union[str, AnyDeferredResult]: ...
+
+@dataclass
+class AsyncFunction(Generic[Params]):
+    name: str
+    description: str
+    parameter_schema: Union[type[Params], Dict[str, Any]]
+    handler: AsyncFunctionHandler[Params]
+```
+
+### 4. 
Extended FunctionCall + +```python +@dataclass +class FunctionCall: + id: str + name: str + arguments: dict[str, Any] + # New async-specific fields: + type: Literal["sync", "async"] | None = None + status: Literal["running", "suspended", "completed", "failed"] | None = None + state: dict[str, Any] | None = None + handler: DeferredHandler | None = None +``` + +## Implementation Examples + +### 1. Simple AskUser Function + +```python +class ShoppingParams(BaseModel): + items: list[str] + budget: float + +class ShoppingHandler: + def __call__(self, params: ShoppingParams) -> AskUserResult: + return DeferredResult[AskUserHandler, UserMessage]( + state={"items": params.items, "budget": params.budget}, + handler=AskUserHandler(question="What additional item would you like to add?") + ) + + def resume(self, user_message: UserMessage, state: dict[str, Any]) -> str: + user_choice = user_message.content + original_items = state["items"] + return f"Adding '{user_choice}' to cart with {original_items}" + +# Usage +shopping_function = AsyncFunction( + name="add_to_cart", + description="Add items to shopping cart", + parameter_schema=ShoppingParams, + handler=ShoppingHandler() +) +``` + +### 2. Multi-Step Approval Workflow + +```python +class ApprovalParams(BaseModel): + document: str + amount: float + +class DocumentApprovalHandler: + def __call__(self, params: ApprovalParams) -> GetApprovalResult: + return DeferredResult[GetApprovalHandler, UserMessage]( + state={ + "document": params.document, + "amount": params.amount, + "step": "manager_approval" + }, + handler=GetApprovalHandler( + prompt=f"Manager approval needed for ${params.amount} expense: '{params.document}'" + ) + ) + + def resume(self, user_message: UserMessage, state: dict[str, Any]) -> Union[str, GetApprovalResult]: + user_input = user_message.content.lower() + + if state["step"] == "manager_approval": + if "yes" in user_input or "approve" in user_input: + return DeferredResult[GetApprovalHandler, UserMessage]( + state={**state, "step": "director_approval", "manager_approved": True}, + handler=GetApprovalHandler( + prompt=f"Director approval needed for ${state['amount']} expense" + ) + ) + else: + return f"Expense rejected by manager: {state['document']}" + + elif state["step"] == "director_approval": + if "yes" in user_input or "approve" in user_input: + return f"Expense '${state['amount']} - {state['document']}' fully approved!" 
+ else: + return f"Expense rejected by director: {state['document']}" +``` + +## Data Flow Diagrams + +### Normal Function Execution Flow + +```mermaid +sequenceDiagram + participant U as User + participant CP as ChatPrompt + participant M as AIModel + participant F as Function + participant LLM as OpenAI API + + U->>CP: send("Calculate 2+2") + CP->>M: generate_text(UserMessage) + M->>LLM: chat.completions.create() + LLM->>M: ModelMessage with function_calls + M->>F: handler(params) + F->>M: "4" + M->>M: Create FunctionMessage + M->>LLM: chat.completions.create() (recursive) + LLM->>M: ModelMessage("The answer is 4") + M->>CP: ModelMessage + CP->>U: ChatSendResult +``` + +### AsyncFunction Suspension Flow + +```mermaid +sequenceDiagram + participant U as User + participant CP as ChatPrompt + participant M as AIModel + participant AF as AsyncFunction + participant LLM as OpenAI API + participant Mem as Memory + + U->>CP: send("Process payment") + CP->>M: generate_text(UserMessage) + M->>LLM: chat.completions.create() + LLM->>M: ModelMessage with function_calls + M->>AF: handler(params) + AF->>M: DeferredResult(handler=GetApprovalHandler) + M->>CP: DeferredResult + CP->>Mem: store_suspended_call() + CP->>U: ChatSendResult(DeferredResult) + + Note over U: User sees: "Approve payment of $100? (yes/no)" +``` + +### AsyncFunction Resume Flow + +```mermaid +sequenceDiagram + participant U as User + participant CP as ChatPrompt + participant AF as AsyncFunction + participant M as AIModel + participant LLM as OpenAI API + participant Mem as Memory + + U->>CP: send("yes") + CP->>Mem: get_suspended_calls() + Mem->>CP: FunctionCall with state + CP->>AF: resume(UserMessage("yes"), state) + AF->>CP: "Payment processed successfully" + CP->>CP: Create FunctionMessage + CP->>M: generate_text(FunctionMessage) + M->>LLM: chat.completions.create() + LLM->>M: ModelMessage("Your payment is complete!") + M->>CP: ModelMessage + CP->>U: ChatSendResult(ModelMessage) +``` + +## Resume Handler Patterns + +AsyncFunction resume handling varies based on the type of interaction required: + +### 1. **Structured Responses: Dedicated Handler Aliases** + +For AsyncFunctions that expect structured data (cards, webhooks, specific payloads), the system generates dedicated handler aliases: + +```python +# Generated handler aliases for structured responses +class TeamsApp(GeneratedActivityHandlerMixin): + + @on_expense_approval_response # Alias for @on_card_action + async def handle_approval(self, activity: AdaptiveCardInvokeActivity): + """Handle approval card responses""" + action = activity.value.get("action") # "approve", "deny", "request_info" + await self.resume_function("expense_approval", action) + + @on_payment_webhook_response # Alias for @on_invoke + async def handle_payment_webhook(self, activity: InvokeActivity): + """Handle payment webhook callbacks""" + webhook_data = activity.value + await self.resume_function("payment_webhook", webhook_data) + + @on_menu_selection_response # Alias for @on_card_action + async def handle_menu_selection(self, activity: AdaptiveCardInvokeActivity): + """Handle menu option selections""" + selection = activity.value.get("selected_option") + await self.resume_function("menu_selection", selection) +``` + +### 2. 
**Text Responses: Handle in `@on_message`** + +For AsyncFunctions expecting text responses (human-in-the-loop, open-ended questions), developers handle resume logic in the generic `@on_message` handler: + +```python +class TeamsApp(GeneratedActivityHandlerMixin): + + @on_message + async def handle_all_text_messages(self, activity: MessageActivity): + """Handle both normal chat AND AsyncFunction text responses""" + + # Strategy 1: Check for suspended HITL functions first + if await self.has_suspended_hitl(): + await self.resume_hitl_function(activity.text) + return + + # Strategy 2: Use heuristics to detect HITL responses + if self.looks_like_hitl_response(activity.text): + await self.try_resume_hitl(activity.text) + return + + # Strategy 3: Normal conversation + result = await self.chat.send(activity.text) + await self.send_response(result) + + async def has_suspended_hitl(self) -> bool: + """Check if any human-in-the-loop functions are suspended""" + suspended = await self.get_suspended_functions(handler_type="ask_user") + return len(suspended) > 0 + + async def resume_hitl_function(self, user_text: str): + """Resume suspended HITL function with user response""" + suspended_hitl = await self.get_suspended_hitl_function() + if suspended_hitl: + result = await suspended_hitl.resume(UserMessage(content=user_text), suspended_hitl.state) + await self.handle_resume_result(result) +``` + +### 3. **Handler Generation Strategy** + +The system generates handler aliases based on AsyncFunction handler types: + +```python +# Handler generation mapping +handler_mappings = { + AskUserHandler: None, # No alias - handle in @on_message + GetApprovalHandler: ("on_{function_name}_response", "card.action", AdaptiveCardInvokeActivity), + WebhookHandler: ("on_{function_name}_webhook", "invoke", InvokeActivity), + SelectFromOptionsHandler: ("on_{function_name}_selection", "card.action", AdaptiveCardInvokeActivity), + TimerHandler: None, # No handler needed - auto-resume +} + +# Documentation generated for each AsyncFunction +""" +AsyncFunction Handler Requirements: + +STRUCTURED RESPONSES (get dedicated handlers): +- GetApprovalHandler -> @on_{name}_response (AdaptiveCardInvokeActivity) +- WebhookHandler -> @on_{name}_webhook (InvokeActivity) +- SelectFromOptionsHandler -> @on_{name}_selection (AdaptiveCardInvokeActivity) + +TEXT RESPONSES (handled in @on_message): +- AskUserHandler -> Handle in @on_message with suspend/resume logic +- HumanInTheLoopHandler -> Handle in @on_message with suspend/resume logic + +Example: + +approval_function = AsyncFunction(name="expense_approval", handler=GetApprovalHandler()) +# REQUIRES: @on_expense_approval_response + +hitl_function = AsyncFunction(name="human_in_loop", handler=AskUserHandler()) +# REQUIRES: Logic in @on_message to detect and resume HITL responses +""" +``` + +### 4. 
**Developer Resume Strategies** + +Developers can choose different strategies for handling ambiguous text message routing: + +**Strategy A: Suspend-First** +```python +@on_message +async def handle(self, activity): + if await self.has_suspended_hitl(): + await self.resume_hitl(activity.text) + else: + await self.normal_chat(activity.text) +``` + +**Strategy B: Keyword Detection** +```python +@on_message +async def handle(self, activity): + if self.is_likely_hitl_response(activity.text): + await self.try_resume_hitl(activity.text) + else: + await self.normal_chat(activity.text) +``` + +**Strategy C: Always Try HITL First** +```python +@on_message +async def handle(self, activity): + if not await self.try_resume_hitl(activity.text): + await self.normal_chat(activity.text) +``` + +## Integration Points + +### 1. ChatPrompt Changes + +```python +class ChatPrompt: + def __init__(self, model: AIModel, *, functions: list[Union[Function[Any], AsyncFunction[Any]]] | None = None): + # Support both Function and AsyncFunction types + self.functions: dict[str, Union[Function[Any], AsyncFunction[Any]]] = { + func.name: func for func in functions + } if functions else {} + + async def send(self, input: str | Message, **kwargs) -> ChatSendResult: + if isinstance(input, str): + input = UserMessage(content=input) + + # Check for suspended functions + suspended_calls = await memory.get_suspended_calls() + + if suspended_calls: + return await self._resume_suspended_function(input, suspended_calls[0]) + else: + return await self._process_new_message(input, **kwargs) + + async def _resume_suspended_function(self, user_input: UserMessage, call: FunctionCall) -> ChatSendResult: + function = self.functions[call.name] + + # Resume the function + result = function.resume(user_input, call.state) + if inspect.isawaitable(result): + result = await result + + if isinstance(result, DeferredResult): + # Function suspended again + await memory.store_suspended_call(call) + return ChatSendResult(response=result) + else: + # Function completed - send result to LLM + function_message = FunctionMessage(content=result, function_id=call.id) + response = await self.model.generate_text(function_message, memory=memory, functions=self.functions) + return ChatSendResult(response=response) +``` + +### 2. AIModel Protocol Update + +```python +class AIModel(Protocol): + async def generate_text( + self, + input: Message, + *, + system: SystemMessage | None = None, + memory: Memory | None = None, + functions: dict[str, Union[Function[BaseModel], AsyncFunction[BaseModel]]] | None = None, + on_chunk: Callable[[str], Awaitable[None]] | None = None, + ) -> Union[ModelMessage, AnyDeferredResult]: + # Can return either completed ModelMessage OR suspension + ... +``` + +### 3. OpenAI Model Changes + +```python +async def generate_text(self, input, **kwargs) -> Union[ModelMessage, AnyDeferredResult]: + function_results = await self._execute_functions(input, functions) + + # Check if any function suspended + for result in function_results: + if isinstance(result, DeferredResult): + return result # Return suspension immediately + + # No suspensions - continue with normal LLM call + openai_messages = self._convert_messages(input, system, messages) + response = await self._client.chat.completions.create(...) 
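+
+    # Design note (sketch): any suspension raised by this round's function
+    # calls was already returned above, so from here on we only deal with
+    # completed tool results and the model's normal response.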
+    model_response = self._convert_response(response)
+
+    if model_response.function_calls:
+        return await self.generate_text(model_response, **kwargs)
+
+    return model_response
+
+async def _execute_functions(self, input: Message, functions) -> list[Union[FunctionMessage, DeferredResult]]:
+    results = []
+
+    if isinstance(input, ModelMessage) and input.function_calls:
+        for call in input.function_calls:
+            function = functions[call.name]
+            parsed_args = parse_function_arguments(function, call.arguments)
+
+            # Sync Functions and AsyncFunctions are invoked the same way;
+            # only the result type differs (str vs. DeferredResult)
+            result = function.handler(parsed_args)
+
+            if inspect.isawaitable(result):
+                result = await result
+
+            if isinstance(result, DeferredResult):
+                # Update call with suspension info
+                call.type = "async"
+                call.status = "suspended"
+                call.state = result.state
+                call.handler = result.handler
+                results.append(result)
+                break  # Stop processing on suspension
+            else:
+                results.append(FunctionMessage(content=result, function_id=call.id))
+
+    return results
+```
+
+### 4. Memory Extensions
+
+```python
+class Memory(Protocol):
+    # Existing methods
+    async def push(self, message: Message) -> None: ...
+    async def get_all(self) -> list[Message]: ...
+
+    # New methods for suspension
+    async def store_suspended_call(self, call: FunctionCall) -> None: ...
+    async def get_suspended_calls(self) -> list[FunctionCall]: ...
+    async def clear_suspended_calls(self) -> None: ...
+
+class ListMemory:
+    def __init__(self):
+        self._messages: list[Message] = []
+        self._suspended_calls: list[FunctionCall] = []
+
+    async def store_suspended_call(self, call: FunctionCall) -> None:
+        self._suspended_calls.append(call)
+
+    async def get_suspended_calls(self) -> list[FunctionCall]:
+        return self._suspended_calls.copy()
+
+    async def clear_suspended_calls(self) -> None:
+        self._suspended_calls.clear()
+```
+
+## User Experience Examples
+
+### Simple Interaction
+
+```python
+# Setup
+chat = ChatPrompt(model=openai_model, functions=[shopping_function])
+
+# Initial call suspends
+result = await chat.send("Add something to my cart")
+print(result.response.handler.question)  # "What additional item would you like to add?"
+
+# Resume automatically
+result = await chat.send("socks")
+print(result.response.content)  # "Adding 'socks' to cart with ['shoes', 'shirts']"
+```
+
+### Multi-Step Workflow
+
+```python
+# Step 1: Suspension
+result = await chat.send("Submit my expense report")
+print(result.response.handler.prompt)  # "Manager approval needed for $250 expense: 'Office supplies'"
+
+# Step 2: Still suspended
+result = await chat.send("approved")
+print(result.response.handler.prompt)  # "Director approval needed for $250 expense"
+
+# Step 3: Completion
+result = await chat.send("approved")
+print(result.response.content)  # "Expense '$250 - Office supplies' fully approved!"
+```
+
+## Implementation Plan: Human-in-the-Loop AsyncFunction
+
+### Phase 1: Core Types (function.py)
+1. Add AskUserHandler and DeferredResult types
+2. Extend FunctionCall with async fields
+3. Create AsyncFunction protocol and dataclass
+4. Update imports and type variables
+
+### Phase 2: Memory Support
+5. Update Memory protocol with suspension methods
+6. Implement ListMemory suspension support
+
+### Phase 3: Model Integration
+7. Update AIModel protocol return type
+8. Modify OpenAI models _execute_functions for DeferredResult
+9. Update ChatPrompt to handle DeferredResult and resume
+
+### Phase 4: Testing
+10. 
Create simple HITL test/example +11. Update __init__.py exports + +## Key Files to Modify + +- `packages/ai/src/microsoft/teams/ai/function.py` - Core types +- `packages/ai/src/microsoft/teams/ai/memory.py` - Memory protocol +- `packages/ai/src/microsoft/teams/ai/list_memory.py` - Memory implementation +- `packages/ai/src/microsoft/teams/ai/ai_model.py` - AIModel protocol +- `packages/ai/src/microsoft/teams/ai/chat_prompt.py` - ChatPrompt logic +- `packages/openai/src/microsoft/teams/openai/completions_model.py` - OpenAI implementation +- `packages/openai/src/microsoft/teams/openai/responses_chat_model.py` - OpenAI implementation +- `packages/ai/src/microsoft/teams/ai/__init__.py` - Exports + +## Backward Compatibility + +- All existing `Function` objects continue working unchanged +- No breaking changes to existing function handlers +- ChatPrompt can mix both sync and async functions +- Models handle both types transparently +- Existing memory implementations work (new methods optional) + +## Benefits + +1. **Type Safety**: Generic DeferredResult ensures handler/resumer type matching +2. **Clean Separation**: ChatPrompt handles conversation flow, models handle execution +3. **Simple Integration**: Minimal changes to existing codebase +4. **Flexible**: Supports various interaction patterns (user input, webhooks, timers) +5. **Natural UX**: Users continue normal conversation flow during suspensions + +--- + +## Appendix: Alternative Approaches Considered + +### A1. Suspend Callback Approach + +**Initial Design:** +```python +async def handler(params: AsyncParams[T], suspend: SuspendCallback) -> str: + if params.type == "init": + suspend(state_data) + elif params.type == "resumed": + # Handle resume +``` + +**Issues:** +- Complex AsyncParams type trying to handle both init and resume data +- Different data shapes (structured params vs simple user responses) +- Magic suspend callback felt indirect +- Harder to follow workflow logic across init/resume + +### A2. String-Based Handlers + +**Alternative Design:** +```python +@dataclass +class DeferredResult: + state: dict[str, Any] + handler: str # "AskUser", "GetApproval" + output: str # Display text +``` + +**Issues:** +- No type safety for handler types +- String typos ("AskUer" vs "AskUser") +- No IDE autocomplete support +- Hard to extend with new handler types + +### A3. Complex Recursion Control + +**Alternative Design:** +```python +async def generate_text(...) -> ModelMessage: + function_results, has_suspended = await self._execute_functions(...) + + if model_response.function_calls and not has_suspended: + return await self.generate_text(...) # Only recurse if no suspensions +``` + +**Issues:** +- Complex state tracking in models +- Need to modify ModelMessage structure +- Suspension state scattered across multiple places +- Harder to reason about control flow + +### A4. Explicit Resume API + +**Alternative Design:** +```python +# Suspend +result = await chat.send("Process payment") + +# Explicit resume method +resume_result = await chat.resume("yes, I approve") +``` + +**Issues:** +- Two different APIs for users to learn +- More complex state management +- Less natural conversation flow +- Users need to understand suspend/resume mechanics + +### A5. Models Return Complex Union Types + +**Alternative Design:** +```python +async def generate_text(...) 
-> Union[ModelMessage, DeferredResult, SuspendedState]: + # Multiple return types based on execution state +``` + +**Issues:** +- Complex type handling throughout system +- Unclear which type to expect when +- More difficult error handling +- Harder to extend with new states + +--- + +## Conclusion + +The chosen design balances simplicity, type safety, and clean integration. The generic DeferredResult approach provides strong typing while the simplified resume flow (AsyncFunction.resume → FunctionMessage → LLM) keeps the implementation straightforward and maintainable. \ No newline at end of file diff --git a/packages/ai/src/microsoft/teams/ai/__init__.py b/packages/ai/src/microsoft/teams/ai/__init__.py index dca604ac..ce3958cc 100644 --- a/packages/ai/src/microsoft/teams/ai/__init__.py +++ b/packages/ai/src/microsoft/teams/ai/__init__.py @@ -3,12 +3,22 @@ Licensed under the MIT License. """ +from . import plugins, utils from .agent import Agent from .ai_model import AIModel from .chat_prompt import ChatPrompt, ChatSendResult -from .function import Function, FunctionCall, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams +from .function import ( + DeferredResult, + Function, + FunctionCall, + FunctionHandler, + FunctionHandlers, + FunctionHandlerWithNoParams, +) from .memory import ListMemory, Memory -from .message import FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage +from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage +from .plugin import AIPluginProtocol, BaseAIPlugin +from .utils import * # noqa: F401, F403 __all__ = [ "ChatSendResult", @@ -19,12 +29,18 @@ "ModelMessage", "SystemMessage", "FunctionMessage", + "DeferredMessage", "Function", "FunctionCall", + "DeferredResult", "Memory", "ListMemory", "AIModel", + "AIPluginProtocol", + "BaseAIPlugin", "FunctionHandler", "FunctionHandlerWithNoParams", "FunctionHandlers", ] +__all__.extend(utils.__all__) +__all__.extend(plugins.__all__) diff --git a/packages/ai/src/microsoft/teams/ai/ai_model.py b/packages/ai/src/microsoft/teams/ai/ai_model.py index 43489750..0f4504cc 100644 --- a/packages/ai/src/microsoft/teams/ai/ai_model.py +++ b/packages/ai/src/microsoft/teams/ai/ai_model.py @@ -9,7 +9,7 @@ from .function import Function from .memory import Memory -from .message import Message, ModelMessage, SystemMessage +from .message import DeferredMessage, Message, ModelMessage, SystemMessage class AIModel(Protocol): @@ -23,13 +23,13 @@ class AIModel(Protocol): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate a text response from the AI model. 
diff --git a/packages/ai/src/microsoft/teams/ai/chat_prompt.py b/packages/ai/src/microsoft/teams/ai/chat_prompt.py index 49908d82..85c9e5dc 100644 --- a/packages/ai/src/microsoft/teams/ai/chat_prompt.py +++ b/packages/ai/src/microsoft/teams/ai/chat_prompt.py @@ -6,14 +6,16 @@ import inspect from dataclasses import dataclass from inspect import isawaitable +from logging import Logger from typing import Any, Awaitable, Callable, Dict, Optional, Self, TypeVar, Union, cast, overload +from microsoft.teams.common.logging import ConsoleLogger from pydantic import BaseModel from .ai_model import AIModel from .function import Function, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams from .memory import Memory -from .message import Message, ModelMessage, SystemMessage, UserMessage +from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage from .plugin import AIPluginProtocol T = TypeVar("T", bound=BaseModel) @@ -28,7 +30,8 @@ class ChatSendResult: calls and plugin processing have been completed. """ - response: ModelMessage # Final model response after processing + response: ModelMessage | None # Final model response after processing + is_deferred: bool = False class ChatPrompt: @@ -45,6 +48,9 @@ def __init__( *, functions: list[Function[Any]] | None = None, plugins: list[AIPluginProtocol] | None = None, + memory: Memory | None = None, + logger: Logger | None = None, + instructions: str | SystemMessage | None = None, ): """ Initialize ChatPrompt with model and optional functions/plugins. @@ -53,10 +59,16 @@ def __init__( model: AI model implementation for text generation functions: Optional list of functions the model can call plugins: Optional list of plugins for extending functionality + memory: Optional memory for conversation context and deferred state + logger: Optional logger for debugging and monitoring + instructions: Optional default system instructions for the model """ self.model = model self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {} self.plugins: list[AIPluginProtocol] = plugins or [] + self.memory = memory + self.logger = logger or ConsoleLogger().create_logger("@teams/ai/chat_prompt") + self.instructions = instructions @overload def with_function(self, function: Function[T]) -> Self: ... @@ -134,9 +146,136 @@ def with_plugin(self, plugin: AIPluginProtocol) -> Self: self.plugins.append(plugin) return self + async def requires_resuming(self) -> bool: + """ + Check if there are any deferred functions that need resuming. + + Returns: + True if there are DeferredMessage objects in memory that need resuming + """ + if not self.memory: + return False + + messages = await self.memory.get_all() + return any(isinstance(msg, DeferredMessage) for msg in messages) + + async def resolve_deferred(self, activity: Any) -> list[str]: + """ + Resolve deferred functions with the provided activity input. + + Only attempts to resolve deferred functions whose resumers can handle + the provided activity type (determined by can_handle method). 
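+
+        Illustrative usage (a sketch; assumes a ChatPrompt constructed
+        with a memory):
+
+            await prompt.resolve_deferred(activity)
+            if not await prompt.requires_resuming():
+                result = await prompt.send(input=None)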
+
+        Args:
+            activity: Activity data to use for resolving deferred functions
+
+        Returns:
+            List of resolution results from successfully resolved functions
+        """
+        if not self.memory:
+            return []
+
+        messages = await self.memory.get_all()
+        deferred_messages = [msg for msg in messages if isinstance(msg, DeferredMessage)]
+
+        if not deferred_messages:
+            return []
+
+        results: list[str] = []
+        updated_messages = messages.copy()  # Work with a copy
+
+        for i, msg in enumerate(updated_messages):
+            if not isinstance(msg, DeferredMessage):
+                continue
+
+            # Try plugins first, then fall back to built-in resumer
+            result = await self._try_resolve_with_plugins(msg, activity)
+            if result is None:
+                result = await self._try_resolve_with_builtin_resumer(msg, activity)
+
+            if result is not None:
+                updated_messages[i] = FunctionMessage(content=result, function_id=msg.function_id)
+                results.append(result)
+
+        # Update memory with resolved messages
+        if results:  # Only update if we actually resolved something
+            await self.memory.set_all(updated_messages)
+
+        return results
+
+    async def _try_resolve_with_plugins(self, msg: DeferredMessage, activity: Any) -> str | None:
+        """
+        Try to resolve a deferred message using plugins.
+
+        Args:
+            msg: The deferred message to resolve
+            activity: Activity data for resolution
+
+        Returns:
+            Result string if a plugin handled it, None otherwise
+        """
+        for plugin in self.plugins:
+            result = await plugin.on_resume(msg.function_name, activity, msg.deferred_result.state)
+            if result is not None:
+                return result
+        return None
+
+    async def _try_resolve_with_builtin_resumer(self, msg: DeferredMessage, activity: Any) -> str | None:
+        """
+        Try to resolve a deferred message using the built-in resumer.
+
+        Args:
+            msg: The deferred message to resolve
+            activity: Activity data for resolution
+
+        Returns:
+            Result string if resolved successfully (resumer errors are
+            captured and returned as an error string), None if the resumer
+            cannot handle this activity; raises ValueError if no resumer
+            is registered for the function
+        """
+        resumer_name = msg.function_name
+        associated_func = self.functions.get(resumer_name)
+
+        if not associated_func or associated_func.resumer is None:
+            raise ValueError(f"Expected a resumer for {resumer_name} but chat prompt was not set up with one")
+
+        # Check if the resumer can handle this type of activity
+        if not associated_func.resumer.can_handle(activity):
+            return None  # Skip this deferred function
+
+        try:
+            # Call the resumer with the activity and saved state
+            result = associated_func.resumer(activity, msg.deferred_result.state)
+            if isawaitable(result):
+                result = await result
+            return result
+
+        except Exception as e:
+            # Return error message instead of raising
+            return f"Error resolving {resumer_name}: {str(e)}"
+
+    async def resume(self, activity: Any) -> ChatSendResult:
+        """
+        Resume deferred functions with the provided activity input.
+
+        If all deferred functions are resolved, automatically continues with
+        normal chat processing driven by the resolved function results in memory.
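+
+        A minimal routing sketch for a message handler (assumes an activity
+        object with a `text` field; `prompt` is this ChatPrompt):
+
+            if await prompt.requires_resuming():
+                result = await prompt.resume(activity)
+            else:
+                result = await prompt.send(activity.text)
+            if result.is_deferred:
+                pass  # still waiting on external input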
+
+        Args:
+            activity: Activity data to use for resolving deferred functions
+
+        Returns:
+            ChatSendResult - either indicating still deferred or containing the chat response
+        """
+        await self.resolve_deferred(activity)
+
+        # If there are still deferred functions pending, return early
+        if await self.requires_resuming():
+            return ChatSendResult(response=None, is_deferred=True)
+
+        return await self.send(input=None)
+
     async def send(
         self,
-        input: str | Message,
+        input: str | Message | None,
         *,
         memory: Memory | None = None,
         on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None,
@@ -158,11 +297,18 @@
         if isinstance(input, str):
             input = UserMessage(content=input)
 
+        # Use constructor instructions as default if none provided
+        if instructions is None:
+            instructions = self.instructions
+
         # Convert string instructions to SystemMessage
         if isinstance(instructions, str):
             instructions = SystemMessage(content=instructions)
 
-        current_input = await self._run_before_send_hooks(input)
+        if input is not None:
+            current_input = await self._run_before_send_hooks(input)
+        else:
+            current_input = None
 
         current_system_message = await self._run_build_instructions_hooks(instructions)
         wrapped_functions = await self._build_wrapped_functions()
@@ -176,10 +322,12 @@ async def on_chunk_fn(chunk: str):
         response = await self.model.generate_text(
             current_input,
             system=current_system_message,
-            memory=memory,
+            memory=memory or self.memory,
             functions=wrapped_functions,
             on_chunk=on_chunk_fn if on_chunk else None,
         )
+        if isinstance(response, list):
+            return ChatSendResult(response=None, is_deferred=True)
 
         current_response = await self._run_after_send_hooks(response)
 
@@ -283,7 +431,9 @@ async def _build_wrapped_functions(self) -> dict[str, Function[BaseModel]] | None:
                 name=func.name,
                 description=func.description,
                 parameter_schema=func.parameter_schema,
-                handler=self._wrap_function_handler(func.handler, func.name),
+                handler=self._wrap_function_handler(cast(FunctionHandler[BaseModel], func.handler), func.name)
+                if func.resumer is None
+                else func.handler,
             )
 
         return wrapped_functions
diff --git a/packages/ai/src/microsoft/teams/ai/function.py b/packages/ai/src/microsoft/teams/ai/function.py
index beafcc1b..ada67971 100644
--- a/packages/ai/src/microsoft/teams/ai/function.py
+++ b/packages/ai/src/microsoft/teams/ai/function.py
@@ -3,12 +3,13 @@
 Licensed under the MIT License.
 """
 
-from dataclasses import dataclass, field
-from typing import Any, Awaitable, Dict, Generic, Protocol, TypeVar, Union
+from dataclasses import dataclass
+from typing import Any, Awaitable, Dict, Generic, Literal, Protocol, TypeVar, Union
 
 from pydantic import BaseModel
 
 Params = TypeVar("Params", bound=BaseModel, contravariant=True)
+ResumableData = TypeVar("ResumableData")
 
 """
 Type variable for function parameter schemas.
@@ -38,6 +39,61 @@ def __call__(self, params: Params) -> Union[str, Awaitable[str]]: ...
 
 
+class DeferredFunctionResumer(Protocol[Params, ResumableData]):
+    """
+    Resumes a deferred function: `can_handle` filters which activities this
+    resumer accepts, and `__call__` receives the resuming activity plus the
+    state saved when the function deferred, returning the final string result.
+    """
+
+    def can_handle(self, activity: Any) -> bool:
+        """
+        Check if this resumer can handle the given activity input.
+
+        Args:
+            activity: The activity data to check
+
+        Returns:
+            True if this resumer can process the activity, False otherwise
+        """
+        ...
+
+    def __call__(self, activity: ResumableData, state: dict[str, Any]) -> Union[str, Awaitable[str]]: ...
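+
+# Illustrative resumer (a sketch, not part of the package): treats any
+# string activity as an approval answer for a previously deferred call.
+#
+#     class ApprovalResumer:
+#         def can_handle(self, activity: Any) -> bool:
+#             return isinstance(activity, str)
+#
+#         async def __call__(self, activity: str, state: dict[str, Any]) -> str:
+#             approved = activity.strip().lower() in ("yes", "approve", "approved")
+#             return "approved" if approved else "rejected"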
+
+
+@dataclass
+class DeferredResult:
+    """
+    Represents a deferred result that can be resumed later on
+    """
+
+    state: dict[str, Any]
+    type: Literal["deferred"] = "deferred"
+
+
+@dataclass
+class FunctionCall:
+    """
+    Represents a function call request from an AI model.
+
+    Contains the function name, unique call ID, and parsed arguments
+    that will be passed to the function handler.
+    """
+
+    id: str  # Unique identifier for this function call
+    name: str  # Name of the function to call
+    arguments: dict[str, Any]  # Parsed arguments for the function
+
+
+class DeferredFunctionHandler(Protocol[Params]):
+    """
+    A deferred function handler does not produce the final string result;
+    it defers the job and returns a DeferredResult carrying the state to
+    save until the function is resumed by its resumer.
+    """
+
+    def __call__(self, params: Params) -> Awaitable[DeferredResult]: ...
+
+
 class FunctionHandlerWithNoParams(Protocol):
     """
     Protocol for function handlers that can be called by AI models.
@@ -81,19 +137,8 @@ class Function(Generic[Params]):
 
     name: str  # Unique identifier for the function
     description: str  # Human-readable description of what the function does
-    parameter_schema: Union[type[Params], Dict[str, Any], None]  # Pydantic model class, JSON schema dict, or None
-    handler: Union[FunctionHandler[Params], FunctionHandlerWithNoParams]  # Function implementation (sync or async)
-
-
-@dataclass
-class FunctionCall:
-    """
-    Represents a function call request from an AI model.
-
-    Contains the function name, unique call ID, and parsed arguments
-    that will be passed to the function handler if any.
-    """
-
-    id: str  # Unique identifier for this function call
-    name: str  # Name of the function to call
-    arguments: dict[str, Any] = field(default_factory=dict[str, Any])  # Parsed arguments for the function
+    parameter_schema: Union[type[Params], Dict[str, Any], None]  # Pydantic model class or JSON schema dict
+    handler: (
+        FunctionHandler[Params] | FunctionHandlerWithNoParams | DeferredFunctionHandler[Params]
+    )  # Function implementation (sync or async)
+    resumer: DeferredFunctionResumer[Params, Any] | None = None  # Optional resumer for deferred functions
diff --git a/packages/ai/src/microsoft/teams/ai/message.py b/packages/ai/src/microsoft/teams/ai/message.py
index 6981d09b..c1e43ce9 100644
--- a/packages/ai/src/microsoft/teams/ai/message.py
+++ b/packages/ai/src/microsoft/teams/ai/message.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Literal, Union
 
-from .function import FunctionCall
+from .function import DeferredResult, FunctionCall
 
 
 @dataclass
@@ -64,7 +64,19 @@ class FunctionMessage:
     role: Literal["function"] = "function"  # Message type identifier
 
 
-Message = Union[UserMessage, ModelMessage, SystemMessage, FunctionMessage]
+@dataclass
+class DeferredMessage:
+    """
+    Represents a function call that is deferred
+    """
+
+    deferred_result: DeferredResult
+    function_name: str
+    function_id: str
+    content: None = None
+
+
+Message = Union[UserMessage, ModelMessage, SystemMessage, FunctionMessage, DeferredMessage]
 """
 Union type representing any message in a conversation.
diff --git a/packages/ai/src/microsoft/teams/ai/plugin.py b/packages/ai/src/microsoft/teams/ai/plugin.py index 82194610..b324090b 100644 --- a/packages/ai/src/microsoft/teams/ai/plugin.py +++ b/packages/ai/src/microsoft/teams/ai/plugin.py @@ -4,7 +4,7 @@ """ from abc import abstractmethod -from typing import Optional, Protocol, TypeVar, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, runtime_checkable from pydantic import BaseModel @@ -111,6 +111,20 @@ async def on_build_instructions(self, instructions: SystemMessage | None) -> Sys """ ... + async def on_resume(self, function_name: str, activity: Any, state: dict[str, Any]) -> str | None: + """ + Called when ChatPrompt is attempting to resume a deferred function. + + Args: + function_name: Name of the function that was deferred + activity: The activity data to use for resolving + state: The state that was saved when function was deferred + + Returns: + Result string if this plugin handled the resuming, None otherwise + """ + ... + class BaseAIPlugin: """ @@ -165,3 +179,7 @@ async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list async def on_build_instructions(self, instructions: SystemMessage | None) -> SystemMessage | None: """Modify the system message before sending to model.""" return instructions + + async def on_resume(self, function_name: str, activity: Any, state: dict[str, Any]) -> str | None: + """Called when ChatPrompt is attempting to resume a deferred function.""" + return None diff --git a/packages/ai/src/microsoft/teams/ai/utils/__init__.py b/packages/ai/src/microsoft/teams/ai/utils/__init__.py new file mode 100644 index 00000000..7d06c498 --- /dev/null +++ b/packages/ai/src/microsoft/teams/ai/utils/__init__.py @@ -0,0 +1,8 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from .function_utils import execute_function, get_function_schema, parse_function_arguments + +__all__ = ["get_function_schema", "parse_function_arguments", "execute_function"] diff --git a/packages/openai/src/microsoft/teams/openai/function_utils.py b/packages/ai/src/microsoft/teams/ai/utils/function_utils.py similarity index 69% rename from packages/openai/src/microsoft/teams/openai/function_utils.py rename to packages/ai/src/microsoft/teams/ai/utils/function_utils.py index 012647a4..d60b83e5 100644 --- a/packages/openai/src/microsoft/teams/openai/function_utils.py +++ b/packages/ai/src/microsoft/teams/ai/utils/function_utils.py @@ -3,11 +3,19 @@ Licensed under the MIT License. 
""" -from typing import Any, Dict, Optional +import inspect +from typing import Any, Dict, Optional, cast -from microsoft.teams.ai import Function from pydantic import BaseModel, ConfigDict, create_model +from ..function import ( + DeferredFunctionHandler, + DeferredResult, + Function, + FunctionHandler, + FunctionHandlerWithNoParams, +) + def get_function_schema(func: Function[Any]) -> Dict[str, Any]: """ @@ -59,3 +67,20 @@ def parse_function_arguments(func: Function[Any], arguments: Dict[str, Any]) -> else: # For Pydantic model schemas, parse normally return func.parameter_schema(**arguments) + + +async def execute_function(function: Function[Any], arguments: Dict[str, Any]) -> str | DeferredResult: + parsed_args = parse_function_arguments(function, arguments) + if parsed_args: + # Handle both sync and async function handlers + handler = cast(FunctionHandler[BaseModel] | DeferredFunctionHandler[BaseModel], function.handler) + result = handler(parsed_args) + else: + handler = cast(FunctionHandlerWithNoParams, function.handler) + result = handler() + + if inspect.isawaitable(result): + fn_res = await result + else: + fn_res = result + return fn_res diff --git a/packages/ai/tests/test_chat_prompt.py b/packages/ai/tests/test_chat_prompt.py index 0f39dab7..94c496c5 100644 --- a/packages/ai/tests/test_chat_prompt.py +++ b/packages/ai/tests/test_chat_prompt.py @@ -139,7 +139,7 @@ async def test_string_input_conversion(self, mock_model: MockAIModel) -> None: result = await prompt.send("Hello world") assert isinstance(result, ChatSendResult) - assert result.response.content == "GENERATED - Hello world" + assert result.response and result.response.content == "GENERATED - Hello world" @pytest.mark.asyncio async def test_memory_updates(self) -> None: @@ -182,6 +182,7 @@ async def test_function_handler_execution(self, mock_function_handler: Mock) -> result = await prompt.send("Call the function") # Verify the function call is in the response + assert isinstance(result.response, ModelMessage) assert result.response.function_calls is not None assert len(result.response.function_calls) == 1 assert result.response.function_calls[0].name == "test_function" @@ -223,10 +224,12 @@ async def test_full_conversation_flow(self, test_function: Function[MockFunction # First exchange result1 = await prompt.send("Hello", memory=memory) + assert isinstance(result1.response, ModelMessage) assert result1.response.content == "GENERATED - Hello" # Second exchange result2 = await prompt.send("How are you?", memory=memory) + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - How are you?" 
# Verify memory contains complete conversation history @@ -284,16 +287,19 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None: # String input result1 = await prompt.send("String input") + assert isinstance(result1.response, ModelMessage) assert result1.response.content == "GENERATED - String input" # UserMessage input user_msg = UserMessage(content="User message") result2 = await prompt.send(user_msg) + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - User message" # ModelMessage input (for function calling scenarios) model_msg = ModelMessage(content="Model message", function_calls=None) result3 = await prompt.send(model_msg) + assert isinstance(result3.response, ModelMessage) assert result3.response.content == "GENERATED - Model message" @pytest.mark.asyncio @@ -334,6 +340,7 @@ def handler_no_params() -> str: # Verify both work in send result = await prompt.send("Test message") + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test message" @@ -431,6 +438,7 @@ async def test_on_before_send_hook(self, mock_model: MockAIModel) -> None: assert mock_model.last_input is not None assert mock_model.last_input.content == "MODIFIED: Original message" # Verify the response reflects the modified input + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - MODIFIED: Original message" @pytest.mark.asyncio @@ -443,6 +451,7 @@ async def test_on_after_send_hook(self, mock_model: MockAIModel) -> None: result = await prompt.send("Test message") assert plugin.after_send_called + assert isinstance(result.response, ModelMessage) assert result.response.content == "RESPONSE_MODIFIED: GENERATED - Test message" @pytest.mark.asyncio @@ -493,6 +502,7 @@ async def test_function_call_hooks(self, mock_function_handler: Mock) -> None: # Verify after hook was called and modified result assert len(plugin.after_function_called) == 1 assert plugin.after_function_called[0][0] == "test_function" + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "FUNCTION_MODIFIED: Function executed successfully" in result.response.content @@ -536,6 +546,7 @@ async def test_multiple_plugins_execution_order(self, mock_model: MockAIModel) - assert plugin2.after_send_called # Input should be modified by both plugins in order + assert isinstance(result.response, ModelMessage) assert result.response.content == "SECOND_RESP: FIRST_RESP: GENERATED - SECOND: FIRST: Original" @pytest.mark.asyncio @@ -557,6 +568,7 @@ async def on_after_send(self, response: ModelMessage) -> ModelMessage | None: result = await prompt.send("Test message") # Should be unchanged since plugin returned None + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test message" @pytest.mark.asyncio @@ -568,6 +580,8 @@ async def test_empty_plugin_list_maintains_compatibility(self, mock_model: MockA result_with = await prompt_with_plugins.send("Test message") result_without = await prompt_without_plugins.send("Test message") + assert isinstance(result_with.response, ModelMessage) + assert isinstance(result_without.response, ModelMessage) assert result_with.response.content == result_without.response.content @pytest.mark.asyncio @@ -592,6 +606,7 @@ async def test_plugin_with_async_function_handler(self, mock_function_handler: M # Verify function was called and result was modified by plugin assert 
len(plugin.before_function_called) == 1 assert len(plugin.after_function_called) == 1 + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "ASYNC_MODIFIED: Function executed successfully" in result.response.content @@ -621,6 +636,7 @@ async def test_base_plugin_default_implementations(self, mock_model: MockAIModel # Should work without any issues using default implementations result = await prompt.send("Test with base plugin") + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test with base plugin" # Test with functions too @@ -631,6 +647,7 @@ def handler(params: MockFunctionParams) -> str: prompt_with_func = ChatPrompt(mock_model, functions=[test_function], plugins=[base_plugin]) result2 = await prompt_with_func.send("Test with function") + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - Test with function" @pytest.mark.asyncio @@ -673,6 +690,7 @@ async def test_comprehensive_plugin_behavior_verification(self, mock_function_ha assert "test_function" in mock_model.last_functions # Verify final response includes all modifications + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "RESP_MOD:" in result.response.content assert "FUNC_MOD: Function executed successfully" in result.response.content diff --git a/packages/openai/tests/test_function_utils.py b/packages/ai/tests/test_function_utils.py similarity index 98% rename from packages/openai/tests/test_function_utils.py rename to packages/ai/tests/test_function_utils.py index dc202983..faf7005c 100644 --- a/packages/openai/tests/test_function_utils.py +++ b/packages/ai/tests/test_function_utils.py @@ -8,8 +8,7 @@ from typing import Optional import pytest -from microsoft.teams.ai import Function -from microsoft.teams.openai.function_utils import get_function_schema, parse_function_arguments +from microsoft.teams.ai import Function, get_function_schema, parse_function_arguments from pydantic import BaseModel, ValidationError diff --git a/packages/api/tests/unit/test_user_client.py b/packages/api/tests/unit/test_user_client.py index 7a1cc838..6092b180 100644 --- a/packages/api/tests/unit/test_user_client.py +++ b/packages/api/tests/unit/test_user_client.py @@ -3,6 +3,7 @@ Licensed under the MIT License. """ +# pyright: basic import pytest from microsoft.teams.api.clients.user import UserClient from microsoft.teams.api.clients.user.params import ( diff --git a/packages/common/tests/test_client.py b/packages/common/tests/test_client.py index f24e253a..36505075 100644 --- a/packages/common/tests/test_client.py +++ b/packages/common/tests/test_client.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import httpx import pytest from microsoft.teams.common.http import Client, ClientOptions, Interceptor diff --git a/packages/common/tests/test_event_emitter.py b/packages/common/tests/test_event_emitter.py index ad6a67c2..6a1ce9b5 100644 --- a/packages/common/tests/test_event_emitter.py +++ b/packages/common/tests/test_event_emitter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import asyncio from unittest.mock import Mock diff --git a/packages/common/tests/test_logging_filter.py b/packages/common/tests/test_logging_filter.py index 0f99409e..6b812101 100644 --- a/packages/common/tests/test_logging_filter.py +++ b/packages/common/tests/test_logging_filter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. 
""" +# pyright: basic + import logging from unittest.mock import MagicMock diff --git a/packages/common/tests/test_logging_formatter.py b/packages/common/tests/test_logging_formatter.py index b30923f0..bd9b5737 100644 --- a/packages/common/tests/test_logging_formatter.py +++ b/packages/common/tests/test_logging_formatter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import logging from typing import Collection, Union diff --git a/packages/openai/src/microsoft/teams/openai/completions_model.py b/packages/openai/src/microsoft/teams/openai/completions_model.py index ad531496..70ec2e67 100644 --- a/packages/openai/src/microsoft/teams/openai/completions_model.py +++ b/packages/openai/src/microsoft/teams/openai/completions_model.py @@ -3,7 +3,6 @@ Licensed under the MIT License. """ -import inspect import json from dataclasses import dataclass from typing import Any, Awaitable, Callable, TypedDict, cast @@ -19,8 +18,11 @@ ModelMessage, SystemMessage, UserMessage, + get_function_schema, ) -from microsoft.teams.ai.function import FunctionHandler, FunctionHandlerWithNoParams +from microsoft.teams.ai.function import DeferredResult +from microsoft.teams.ai.message import DeferredMessage +from microsoft.teams.ai.utils.function_utils import execute_function from microsoft.teams.openai.common import OpenAIBaseModel from pydantic import BaseModel @@ -39,8 +41,6 @@ ChatCompletionUserMessageParam, ) -from .function_utils import get_function_schema, parse_function_arguments - class _ToolCallData(TypedDict): """ @@ -68,13 +68,13 @@ class OpenAICompletionsAIModel(OpenAIBaseModel, AIModel): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate text using OpenAI Chat Completions API. @@ -97,28 +97,36 @@ async def generate_text( if memory is None: memory = ListMemory() - # Execute any pending function calls first - function_results = await self._execute_functions(input, functions) - # Get conversation history from memory (make a copy to avoid modifying memory's internal state) messages = list(await memory.get_all()) + + # Execute any pending function calls first + function_results = await self._execute_functions(input, messages, functions) self.logger.debug(f"Retrieved {len(messages)} messages from memory, {len(function_results)} function results") # Push current input to memory - await memory.push(input) + if input is not None: + await memory.push(input) # Push function results to memory and add to messages + deferred_messages: list[DeferredMessage] = [] if function_results: # Add the original ModelMessage with function_calls to messages first - messages.append(input) + if input is not None: + messages.append(input) for result in function_results: await memory.push(result) messages.append(result) + if isinstance(result, DeferredMessage): + deferred_messages.append(result) # Don't add input again at the end - Order matters here! 
             input_to_send = None
         else:
             input_to_send = input
 
+        if len(deferred_messages) > 0:
+            return deferred_messages
+
         # Convert messages to OpenAI format
         openai_messages = self._convert_messages(input_to_send, system, messages)
         self.logger.debug(f"Converted to {len(openai_messages)} OpenAI messages")
@@ -153,35 +161,38 @@
         return model_response
 
     async def _execute_functions(
-        self, input: Message, functions: dict[str, Function[BaseModel]] | None
-    ) -> list[FunctionMessage]:
+        self, input: Message | None, memory_messages: list[Message], functions: dict[str, Function[BaseModel]] | None
+    ) -> list[FunctionMessage | DeferredMessage]:
         """Execute any pending function calls in the input message."""
-        function_results: list[FunctionMessage] = []
+        function_results: list[FunctionMessage | DeferredMessage] = []
 
         if isinstance(input, ModelMessage) and input.function_calls:
             # Execute any pending function calls
             self.logger.debug(f"Executing {len(input.function_calls)} function calls")
             for call in input.function_calls:
+                existing_function_result = next(
+                    (
+                        message
+                        for message in memory_messages
+                        if isinstance(message, FunctionMessage) and message.function_id == call.id
+                    ),
+                    None,
+                )
+                if existing_function_result is not None:
+                    self.logger.debug(f"{call.name} already called. Skipping execution")
+                    continue
                 if functions and call.name in functions:
                     function = functions[call.name]
                     try:
                         # Parse arguments using utility function
-                        parsed_args = parse_function_arguments(function, call.arguments)
-                        if parsed_args:
-                            # Handle both sync and async function handlers
-                            handler = cast(FunctionHandler[BaseModel], function.handler)
-                            result = handler(parsed_args)
+                        fn_res = await execute_function(function, call.arguments)
+                        if isinstance(fn_res, DeferredResult):
+                            function_results.append(
+                                DeferredMessage(deferred_result=fn_res, function_name=call.name, function_id=call.id)
+                            )
                         else:
-                            handler = cast(FunctionHandlerWithNoParams, function.handler)
-                            result = handler()
-
-                        if inspect.isawaitable(result):
-                            fn_res = await result
-                        else:
-                            fn_res = result
-
-                        # Create function result message
-                        function_results.append(FunctionMessage(content=fn_res, function_id=call.id))
+                            # Create function result message
+                            function_results.append(FunctionMessage(content=fn_res, function_id=call.id))
                     except Exception as e:
                         self.logger.error(e)
                         # Handle function execution errors
@@ -264,37 +274,43 @@
         return openai_messages
 
     def _convert_message_to_openai_format(self, message: Message) -> ChatCompletionMessageParam:
-        if isinstance(
-            message,
-            UserMessage,
-        ):
-            return ChatCompletionUserMessageParam(role=message.role, content=message.content)
-        if isinstance(message, SystemMessage):
-            return ChatCompletionSystemMessageParam(role=message.role, content=message.content)
-
-        elif isinstance(message, FunctionMessage):
-            return ChatCompletionToolMessageParam(
-                role="tool",
-                content=message.content or [],
-                tool_call_id=message.function_id,
-            )
-        elif isinstance(message, ModelMessage):  # pyright: ignore [reportUnnecessaryIsInstance]
-            if message.function_calls:
-                tool_calls = [
-                    ChatCompletionMessageFunctionToolCallParam(
-                        id=call.id,
-                        function={"name": call.name, "arguments": json.dumps(call.arguments)},
-                        type="function",
+        match message:
+            case UserMessage():
+                return ChatCompletionUserMessageParam(role=message.role, content=message.content)
+            case SystemMessage():
+                return ChatCompletionSystemMessageParam(role=message.role, content=message.content)
+            case FunctionMessage():
+                return 
ChatCompletionToolMessageParam( + role="tool", + content=message.content or [], + tool_call_id=message.function_id, + ) + case ModelMessage(): + if message.function_calls: + tool_calls = [ + ChatCompletionMessageFunctionToolCallParam( + id=call.id, + function={"name": call.name, "arguments": json.dumps(call.arguments)}, + type="function", + ) + for call in message.function_calls + ] + else: + # we need to do this cast because Completions expects tool_calls to be >= 1, + # but the type is not Optional + tool_calls = cast(list[ChatCompletionMessageFunctionToolCallParam], None) + return ChatCompletionAssistantMessageParam( + role="assistant", content=message.content, tool_calls=tool_calls + ) + case DeferredMessage(): + raise ValueError( + ( + "A deferred_message should not be sent to OpenAI. It needs to be resolved " + "and converted to a FunctionMessage." ) - for call in message.function_calls - ] - else: - # we need to do this cast because Completions expects tool_calls to be >= 1, - # but the type is not Optional - tool_calls = cast(list[ChatCompletionMessageFunctionToolCallParam], None) - return ChatCompletionAssistantMessageParam(role="assistant", content=message.content, tool_calls=tool_calls) - else: - raise Exception(f"Message {message.role} not supported") + ) + case _: + raise Exception(f"Message {message.role} not supported") def _convert_functions(self, functions: dict[str, Function[BaseModel]]) -> list[ChatCompletionToolUnionParam]: function_values = functions.values() diff --git a/packages/openai/src/microsoft/teams/openai/responses_chat_model.py b/packages/openai/src/microsoft/teams/openai/responses_chat_model.py index 5ffd8b4d..c5f64f86 100644 --- a/packages/openai/src/microsoft/teams/openai/responses_chat_model.py +++ b/packages/openai/src/microsoft/teams/openai/responses_chat_model.py @@ -10,6 +10,7 @@ from microsoft.teams.ai import ( AIModel, + DeferredMessage, Function, FunctionCall, FunctionHandler, @@ -21,6 +22,8 @@ ModelMessage, SystemMessage, UserMessage, + get_function_schema, + parse_function_arguments, ) from pydantic import BaseModel @@ -40,7 +43,6 @@ ) from .common import OpenAIBaseModel -from .function_utils import get_function_schema, parse_function_arguments @dataclass @@ -57,13 +59,13 @@ class OpenAIResponsesAIModel(OpenAIBaseModel, AIModel): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate text using OpenAI Responses API. 
@@ -95,13 +97,13 @@ async def generate_text(
     async def _send_stateful(
         self,
-        input: Message,
+        input: Message | None,
         system: SystemMessage | None,
         memory: Memory,
         functions: dict[str, Function[BaseModel]] | None,
         on_chunk: Callable[[str], Awaitable[None]] | None,
         function_results: list[FunctionMessage],
-    ) -> ModelMessage:
+    ) -> ModelMessage | list[DeferredMessage]:
         """Handle stateful conversation using OpenAI Responses API state management."""
         # Get response IDs from memory - OpenAI manages conversation state
         messages = list(await memory.get_all())
@@ -163,21 +165,22 @@ async def _send_stateful(
     async def _send_stateless(
         self,
-        input: Message,
+        input: Message | None,
         system: SystemMessage | None,
         memory: Memory,
         functions: dict[str, Function[BaseModel]] | None,
         on_chunk: Callable[[str], Awaitable[None]] | None,
         function_results: list[FunctionMessage],
-    ) -> ModelMessage:
+    ) -> ModelMessage | list[DeferredMessage]:
         """Handle stateless conversation using standard OpenAI API pattern."""
         # Get conversation history from memory (make a copy to avoid modifying memory's internal state)
         messages = list(await memory.get_all())
         self.logger.debug(f"Retrieved {len(messages)} messages from memory")

-        # Push current input to memory
-        await memory.push(input)
-        messages.append(input)
+        if input:
+            # Push current input to memory
+            await memory.push(input)
+            messages.append(input)

         # Push function results to memory and add to messages
         if function_results:
@@ -229,7 +232,7 @@ async def _send_stateless(
         return model_response

     async def _execute_functions(
-        self, input: Message, functions: dict[str, Function[BaseModel]] | None
+        self, input: Message | None, functions: dict[str, Function[BaseModel]] | None
     ) -> list[FunctionMessage]:
         """Execute any pending function calls in the input message."""
         function_results: list[FunctionMessage] = []
diff --git a/packages/openai/tests/test_openai_completions_model.py b/packages/openai/tests/test_openai_completions_model.py
index 2da098a2..26063486 100644
--- a/packages/openai/tests/test_openai_completions_model.py
+++ b/packages/openai/tests/test_openai_completions_model.py
@@ -76,6 +76,7 @@ async def test_generate_text_basic_message(
     result = await model.generate_text(input_msg)

     # Assertions
+    assert isinstance(result, ModelMessage)
     assert result.content == "Hello, world!"
     assert result.function_calls is None
diff --git a/tests/ai-test/src/handlers/function_calling.py b/tests/ai-test/src/handlers/function_calling.py
index 141d0056..1570b66e 100644
--- a/tests/ai-test/src/handlers/function_calling.py
+++ b/tests/ai-test/src/handlers/function_calling.py
@@ -68,7 +68,7 @@ async def handle_pokemon_search(model: AIModel, ctx: ActivityContext[MessageActi
         input=ctx.activity.text, instructions="You are a helpful assistant that can look up Pokemon for the user."
     )

-    if chat_result.response.content:
+    if chat_result.response and chat_result.response.content:
         message = MessageActivityInput(text=chat_result.response.content).add_ai_generated()
         await ctx.send(message)
     else:
@@ -129,7 +129,7 @@ async def handle_multiple_functions(model: AIModel, ctx: ActivityContext[Message
         ),
     )

-    if chat_result.response.content:
+    if chat_result.response and chat_result.response.content:
         message = MessageActivityInput(text=chat_result.response.content).add_ai_generated()
         await ctx.send(message)
     else:
diff --git a/tests/ai-test/src/handlers/memory_management.py b/tests/ai-test/src/handlers/memory_management.py
index 1f10b9c1..368cd117 100644
--- a/tests/ai-test/src/handlers/memory_management.py
+++ b/tests/ai-test/src/handlers/memory_management.py
@@ -39,7 +39,7 @@ async def handle_stateful_conversation(model: AIModel, ctx: ActivityContext[Mess
         input=ctx.activity.text, instructions="You are a helpful assistant that remembers our previous conversation."
     )

-    if chat_result.response.content:
+    if chat_result.response and chat_result.response.content:
         message = MessageActivityInput(text=chat_result.response.content).add_ai_generated()
         await ctx.send(message)
     else:
diff --git a/tests/ai-test/src/main.py b/tests/ai-test/src/main.py
index e1ae3c23..54a03e04 100644
--- a/tests/ai-test/src/main.py
+++ b/tests/ai-test/src/main.py
@@ -68,7 +68,7 @@ async def handle_simple_chat(ctx: ActivityContext[MessageActivity]):
         input=ctx.activity.text, instructions="You are a friendly assistant who talks like a pirate"
     )

-    if chat_result.response.content:
+    if chat_result.response and chat_result.response.content:
         message = MessageActivityInput(text=chat_result.response.content).add_ai_generated()
         await ctx.send(message)
@@ -107,7 +107,7 @@ async def handle_streaming(ctx: ActivityContext[MessageActivity]):
     if hasattr(ctx.activity.conversation, "is_group") and ctx.activity.conversation.is_group:
         # Group chat - send final response
-        if chat_result.response.content:
+        if chat_result.response and chat_result.response.content:
             message = MessageActivityInput(text=chat_result.response.content).add_ai_generated()
             await ctx.send(message)
         else:
@@ -167,7 +167,7 @@ async def handle_feedback_demo(ctx: ActivityContext[MessageActivity]):
         input="Tell me a short joke", instructions="You are a comedian. Keep responses brief and funny."
     )

-    if chat_result.response.content:
+    if chat_result.response and chat_result.response.content:
         # Create message with feedback enabled and initialize storage
         message = MessageActivityInput(text=chat_result.response.content).add_ai_generated().add_feedback()
         sent_message = await ctx.send(message)
diff --git a/tests/defferred_ai/README.md b/tests/defferred_ai/README.md
new file mode 100644
index 00000000..da3b70c7
--- /dev/null
+++ b/tests/defferred_ai/README.md
@@ -0,0 +1,82 @@
+# Deferred AI Test
+
+Test application demonstrating an approval workflow using `ApprovalPlugin`.
+
+## What This Demonstrates
+
+This test shows how to use the `ApprovalPlugin` to wrap functions that require human approval before execution.
+
+### How It Works
+
+1. **User asks to buy stocks**: "Buy 10 shares of MSFT"
+2. **AI calls the function**: The AI model calls `buy_stock(stock="MSFT", quantity=10)`
+3. **Plugin intercepts**: ApprovalPlugin wraps the function and defers execution
+4. **Approval requested**: User sees approval request with function details
+5. **User responds**: "yes" or "no"
+6. **Plugin resumes**:
+   - If approved → executes original function and returns result
+   - If denied → returns cancellation message
+
+## Usage
+
+```bash
+# Start the app
+python src/main.py
+
+# In chat, ask to buy stocks
+> Buy 10 shares of MSFT
+
+# You'll see approval request
+> Approval Required
+> Function: buy_stock
+> Parameters: {'stock': 'MSFT', 'quantity': 10}
+>
+> Please respond with:
+> - 'yes' or 'approve' to confirm
+> - 'no' or 'deny' to cancel
+
+# Respond with approval
+> yes
+
+# Stock purchase executes
+> ✅ Successfully purchased 10 shares of MSFT. Order executed at market price.
+```
+
+## Code Overview
+
+```python
+# Create your function
+stock_function = Function(
+    name="buy_stock",
+    description="purchase stocks by specifying ticker symbol and quantity",
+    parameter_schema=BuyStockParams,
+    handler=lambda params: f"✅ Successfully purchased {params.quantity} shares of {params.stock}",
+)
+
+# Wrap it with approval
+approval_plugin = ApprovalPlugin(
+    sender=ctx,
+    functions=[stock_function],  # Functions that need approval
+)
+
+# Add to ChatPrompt
+chat_prompt = ChatPrompt(
+    model=ai_model,
+    functions=[stock_function],
+    memory=memory,
+).with_plugin(approval_plugin)
+
+# Use normally - approval happens automatically
+if await chat_prompt.requires_resuming():
+    result = await chat_prompt.resume(ctx.activity)
+else:
+    result = await chat_prompt.send(ctx.activity.text)
+```
+
+## Key Benefits
+
+- ✅ **Clean code**: Just specify which functions need approval
+- ✅ **No function modification**: Original functions stay unchanged
+- ✅ **Automatic deferral**: Plugin handles all the deferred execution logic
+- ✅ **Reusable**: Same plugin works across different ChatPrompts
+- ✅ **Natural UX**: AI calls functions normally, approval is transparent
diff --git a/tests/defferred_ai/pyproject.toml b/tests/defferred_ai/pyproject.toml
new file mode 100644
index 00000000..d262dc5a
--- /dev/null
+++ b/tests/defferred_ai/pyproject.toml
@@ -0,0 +1,13 @@
+[project]
+name = "defferred_ai"
+version = "0.1.0"
+description = "testing deferred tools"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "dotenv>=0.9.9",
+    "microsoft-teams-apps",
+]
+
+[tool.uv.sources]
+microsoft-teams-apps = { workspace = true }
diff --git a/tests/defferred_ai/src/approval.py b/tests/defferred_ai/src/approval.py
new file mode 100644
index 00000000..f875048d
--- /dev/null
+++ b/tests/defferred_ai/src/approval.py
@@ -0,0 +1,165 @@
+"""
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
+"""
+
+from typing import Any, Callable, Protocol, TypeVar
+
+from microsoft.teams.ai import DeferredResult, Function
+from microsoft.teams.ai.function import DeferredFunctionResumer
+from microsoft.teams.api import MessageActivityInput
+from microsoft.teams.api.activities.message.message import MessageActivity
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class MessageSender(Protocol):
+    """Protocol for anything that can send messages."""
+
+    async def send(self, message: str | MessageActivityInput) -> Any:
+        """Send a message."""
+        ...
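+
+
+# NOTE (illustrative, not part of the plugin API): any object with a matching
+# async `send` satisfies this protocol — the handlers below pass the
+# ActivityContext directly, and a hypothetical test double could look like:
+#
+#     class RecordingSender:
+#         """Collects outgoing messages instead of sending them to Teams."""
+#
+#         def __init__(self) -> None:
+#             self.sent: list[str | MessageActivityInput] = []
+#
+#         async def send(self, message: str | MessageActivityInput) -> Any:
+#             self.sent.append(message)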
+
+
+class ApprovalParams(BaseModel):
+    query: str
+
+
+def create_approval_function(sender: MessageSender) -> Function[ApprovalParams]:
+    """Factory function to create an approval function with captured message sender."""
+
+    async def approval_handler(params: ApprovalParams) -> DeferredResult:
+        """Handler that defers execution and sends approval request."""
+        # Send the approval request message immediately
+        await sender.send(
+            "⏳ **Approval Required**\n\n"
+            f"**Query:** {params.query}\n\n"
+            "Please respond with:\n"
+            "• 'yes' or 'approve' to confirm\n"
+            "• 'no' or 'deny' to cancel"
+        )
+
+        return DeferredResult(
+            state={"query": params.query},
+        )
+
+    class HumanApprovalResumer(DeferredFunctionResumer[ApprovalParams, Any]):
+        """Resumer that handles human approval responses."""
+
+        def can_handle(self, activity: Any) -> bool:
+            """Check if this is a text message that looks like an approval response."""
+            if isinstance(activity, MessageActivity):
+                text = activity.text.lower().strip()
+                approval_keywords = ["yes", "no", "approve", "deny", "reject", "confirm", "cancel"]
+                return any(keyword in text for keyword in approval_keywords)
+            return False
+
+        async def __call__(self, activity: Any, resumable_data: dict[str, Any]) -> str:
+            """Process the human approval response."""
+            assert isinstance(activity, MessageActivity), "activity must be a MessageActivity"
+            user_response = activity.text.lower().strip()
+            query = resumable_data.get("query", "unknown query")
+
+            await sender.send("[DEBUG] got approval result from user")
+            if any(word in user_response for word in ["yes", "approve", "confirm"]):
+                return f"✅ Approved: {query}\nApproval granted by user."
+            else:
+                return f"❌ Denied: {query}\nApproval denied by user."
+
+    return Function(
+        name="get_human_approval",
+        description=(
+            "You must ALWAYS use this tool to get approvals. Do NOT ask for approval directly without using this tool"
+        ),
+        parameter_schema=ApprovalParams,
+        handler=approval_handler,
+        resumer=HumanApprovalResumer(),
+    )
+
+
+def create_approval_wrapped_function[T: BaseModel](
+    sender: MessageSender, original_function: Function[T], create_approval_message: Callable[[T], str]
+) -> Function[T]:
+    """
+    Wrap an existing function with approval workflow.
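+
+    The wrapped function defers on first call: it sends the message produced by
+    `create_approval_message(params)` and returns a DeferredResult carrying the
+    serialized params; the resumer then replays the original handler only when
+    the user's reply contains an approval keyword.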
+
+    Args:
+        sender: Message sender for approval requests
+        original_function: The function to wrap with approval
+        create_approval_message: Function to create approval message based on params
+
+    Returns:
+        A new function that requires approval before executing the original
+    """
+
+    async def wrapped_handler(params: T) -> DeferredResult:
+        """Handler that requests approval before executing the original function."""
+
+        # Create approval message using the provided callback
+        approval_message = create_approval_message(params)
+
+        print(f"[APPROVAL WRAPPER] Requesting approval for: {original_function.name}")
+
+        # Send the approval request message
+        await sender.send(approval_message)
+
+        # Save the call details in state for resume
+        return DeferredResult(
+            state={
+                "params": params.model_dump(),
+            },
+        )
+
+    class ApprovalWrappedResumer(DeferredFunctionResumer[T, Any]):
+        """Resumer that executes the original function after approval."""
+
+        def can_handle(self, activity: Any) -> bool:
+            """Check if this is a text message that looks like an approval response."""
+            if isinstance(activity, MessageActivity):
+                text = activity.text.lower().strip()
+                approval_keywords = ["yes", "no", "approve", "deny", "reject", "confirm", "cancel"]
+                return any(keyword in text for keyword in approval_keywords)
+            return False
+
+        async def __call__(self, activity: Any, resumable_data: dict[str, Any]) -> str:
+            """Process the approval response and execute original function if approved."""
+            assert isinstance(activity, MessageActivity), "expected activity to be a MessageActivity"
+            user_response = activity.text.lower().strip()
+            saved_params = resumable_data.get("params", {})
+
+            await sender.send("[DEBUG] Got approval result!")
+            if any(word in user_response for word in ["yes", "approve", "confirm"]):
+                print("[APPROVAL WRAPPER] Approved, executing original function")
+
+                try:
+                    # Recreate the params object and call the original function;
+                    # parameter_schema should be the params type for Function[T], so verify before instantiating
+                    schema_type = original_function.parameter_schema
+                    if not isinstance(schema_type, type):
+                        raise ValueError(f"Expected parameter_schema to be a type, got {type(schema_type)}")
+
+                    params_instance = schema_type(**saved_params)
+                    result = original_function.handler(params_instance)
+
+                    # Handle async results
+                    from inspect import isawaitable
+
+                    if isawaitable(result):
+                        result = await result
+
+                    return f"✅ **Approved and Executed**\n\n{result}"
+
+                except Exception as e:
+                    return f"❌ **Approved but Failed**\nError executing {original_function.name}: {str(e)}"
+            else:
+                return "❌ **Cancelled**\nExecution denied by user."
+
+    # Return wrapped function using original function's name, description, and param_schema
+    return Function[T](
+        name=original_function.name,
+        description=original_function.description,
+        parameter_schema=original_function.parameter_schema,
+        handler=wrapped_handler,
+        resumer=ApprovalWrappedResumer(),
+    )
diff --git a/tests/defferred_ai/src/approval_for_function.py b/tests/defferred_ai/src/approval_for_function.py
new file mode 100644
index 00000000..edfefe89
--- /dev/null
+++ b/tests/defferred_ai/src/approval_for_function.py
@@ -0,0 +1,170 @@
+"""
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
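+
+Defines ApprovalPlugin, which puts selected functions behind a human-approval
+workflow: calls are deferred, the user is asked to approve, and the original
+function runs only after an approving reply.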
+""" + +import logging +from typing import Any, Protocol + +from microsoft.teams.ai import BaseAIPlugin, DeferredResult, Function, execute_function +from microsoft.teams.api import MessageActivityInput +from microsoft.teams.common.logging.console import ConsoleLogger +from pydantic import BaseModel + + +class MessageSender(Protocol): + """Protocol for anything that can send messages.""" + + async def send(self, message: str | MessageActivityInput) -> Any: + """Send a message.""" + ... + + +class ApprovalPlugin(BaseAIPlugin): + """ + Plugin that wraps specified functions with approval workflow. + + This plugin intercepts function calls, requests approval from the user, + and executes the original function only after approval is granted. + """ + + def __init__(self, sender: MessageSender, functions: list[Function[Any]], *, logger: logging.Logger | None = None): + """ + Initialize the approval plugin. + + Args: + sender: Message sender for sending approval requests + fn_names: List of function names to wrap with approval workflow + """ + super().__init__("approval") + self.sender = sender + self.logger: logging.Logger = logger or ConsoleLogger().create_logger("ApprovalPlugin") + self._original_functions: dict[str, Function[BaseModel]] = {f.name: f for f in functions} + + async def on_resume(self, function_name: str, activity: Any, state: dict[str, Any]) -> str | None: + """ + Handle approval responses when resuming deferred functions. + + Args: + function_name: Name of the function that was deferred + activity: Activity data to use for resolving + state: The state that was saved when function was deferred + + Returns: + Result string if this plugin handled the approval, None otherwise + """ + # Only handle functions we're wrapping + if function_name not in self._original_functions: + return None + + # Check if this activity has text (duck typing for MessageActivity) + if not hasattr(activity, "text") or not isinstance(activity.text, str): + return None + + text = activity.text.lower().strip() + approval_keywords = ["yes", "no", "approve", "deny", "reject", "confirm", "cancel"] + if not any(keyword in text for keyword in approval_keywords): + return None # Not an approval response yet + + # Handle approval/denial + if any(word in text for word in ["yes", "approve", "confirm"]): + return await self._execute_wrapped_function(function_name, state) + else: + return f"Denied: Execution of {function_name} was cancelled by user." + + async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None: + """ + Wrap specified functions with approval workflow. + + Args: + functions: Current list of available functions + + Returns: + Updated function list with wrapped functions + """ + # Wrap each specified function + wrapped_functions: list[Function[BaseModel]] = [] + for func in functions: + if func.name in self._original_functions: + if func.resumer is not None: + self.logger.warning( + f"{func.name} seems to be a resumable function. ApprovalPlugin only works" + "for functions that are not resumable themselves." + ) + continue + wrapped_func = self._create_wrapped_function(func) + wrapped_functions.append(wrapped_func) + else: + wrapped_functions.append(func) + + return wrapped_functions + + def _create_wrapped_function(self, original_func: Function[BaseModel]) -> Function[BaseModel]: + """ + Create a wrapped version of a function that requires approval. 
+
+        Args:
+            original_func: The original function to wrap
+
+        Returns:
+            Wrapped function that defers for approval before execution
+        """
+        # The original function stays registered in self._original_functions for
+        # execution once approval is granted
+        self.logger.debug(f"Wrapping {original_func.name} with ApprovalPlugin Function")
+
+        async def wrapped_handler(params: BaseModel) -> DeferredResult:
+            """Handler that requests approval before executing original function."""
+            # Send approval request
+            await self.sender.send(
+                f"Approval Required\n\n"
+                f"Function: {original_func.name}\n"
+                f"Parameters: {params.model_dump()}\n\n"
+                "Please respond with:\n"
+                "- 'yes' or 'approve' to confirm\n"
+                "- 'no' or 'deny' to cancel"
+            )
+
+            # Save params for later execution
+            return DeferredResult(
+                state={
+                    "params": params.model_dump(),
+                    "original_function_name": original_func.name,
+                },
+            )
+
+        return Function(
+            name=original_func.name,
+            description=original_func.description,
+            parameter_schema=original_func.parameter_schema,
+            handler=wrapped_handler,
+            resumer=None,  # Plugin handles resuming via on_resume hook
+        )
+
+    async def _execute_wrapped_function(self, function_name: str, state: dict[str, Any]) -> str:
+        """
+        Execute the original wrapped function after approval.
+
+        Args:
+            function_name: Name of the function to execute
+            state: State containing saved parameters
+
+        Returns:
+            Result from executing the original function
+        """
+        original_func = self._original_functions.get(function_name)
+        if not original_func:
+            raise ValueError(f"Could not re-run original function {function_name} because it no longer exists")
+        try:
+            # Recreate params from saved state
+            saved_params = state.get("params", {})
+            self.logger.info(f"Running original function {function_name} after approval")
+            result = await execute_function(original_func, saved_params)
+            if isinstance(result, DeferredResult):
+                raise ValueError(
+                    "Functions that use ApprovalPlugin cannot be deferrable, "
+                    f"but {original_func.name} just returned a DeferredResult"
+                )
+            return result
+        except Exception as e:
+            return f"Approved but Failed\nError executing {function_name}: {str(e)}"
diff --git a/tests/defferred_ai/src/main.py b/tests/defferred_ai/src/main.py
new file mode 100644
index 00000000..91074035
--- /dev/null
+++ b/tests/defferred_ai/src/main.py
@@ -0,0 +1,100 @@
+"""
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
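+
+Demo entry point: a buy_stock function is wrapped by ApprovalPlugin, so stock
+purchases execute only after the user approves them in chat.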
+""" + +import asyncio +from os import getenv + +from approval_for_function import ApprovalPlugin +from dotenv import find_dotenv, load_dotenv +from microsoft.teams.ai import ChatPrompt, Function, ListMemory +from microsoft.teams.api import MessageActivity, MessageActivityInput +from microsoft.teams.apps import ActivityContext, App +from microsoft.teams.devtools import DevToolsPlugin +from microsoft.teams.openai import OpenAICompletionsAIModel +from pydantic import BaseModel + +load_dotenv(find_dotenv(usecwd=True)) + + +app = App(plugins=[DevToolsPlugin()]) + + +def get_required_env(key: str) -> str: + value = getenv(key) + if not value: + raise ValueError(f"Required environment variable {key} is not set") + return value + + +# Get OpenAI model (like in ai-test) +AZURE_OPENAI_MODEL = get_required_env("AZURE_OPENAI_MODEL") +ai_model = OpenAICompletionsAIModel(model=AZURE_OPENAI_MODEL) + + +class BuyStockParams(BaseModel): + stock: str + quantity: int + + +def create_buy_stock_function() -> Function[BuyStockParams]: + """Create a buy stock function.""" + + def handler(params: BuyStockParams) -> str: + print("Actually running the buy stock fn") + return f"✅ Successfully purchased {params.quantity} shares of {params.stock}. Order executed at market price." + + return Function( + name="buy_stock", + description="purchase stocks by specifying ticker symbol and quantity", + parameter_schema=BuyStockParams, + handler=handler, + ) + + +# Global memory instance +memory = ListMemory() + + +@app.on_message +async def handle_stock_trading(ctx: ActivityContext[MessageActivity]) -> None: + """Handle stock trading with approval using ApprovalPlugin.""" + print(f"[STOCK TRADING] Message received: {ctx.activity.text}") + + try: + # Create stock function (will be wrapped by plugin) + stock_function = create_buy_stock_function() + + # Create approval plugin with fn_names to wrap + approval_plugin = ApprovalPlugin(sender=ctx, functions=[stock_function]) + + chat_prompt = ChatPrompt( + instructions=( + "You are a helpful assistant. Use the available stock trading tool when users want to buy stocks." 
+ ), + model=ai_model, + functions=[stock_function], # Plugin will wrap this function + memory=memory, + ).with_plugin(approval_plugin) + + # Handle deferred functions or normal chat + if await chat_prompt.requires_resuming(): + chat_result = await chat_prompt.resume(ctx.activity) + else: + chat_result = await chat_prompt.send(input=ctx.activity.text) + + if chat_result.response and chat_result.response.content: + message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() + await ctx.send(message) + elif chat_result.is_deferred: + # Approval message already sent by the plugin + pass + + except Exception as e: + print(f"[STOCK TRADING] Error: {str(e)}") + await ctx.send(f"❌ Error: {str(e)}") + + +if __name__ == "__main__": + asyncio.run(app.start()) diff --git a/tests/mcp-client/src/main.py b/tests/mcp-client/src/main.py index ec6989ba..30b61021 100644 --- a/tests/mcp-client/src/main.py +++ b/tests/mcp-client/src/main.py @@ -88,7 +88,7 @@ async def handle_agent_chat(ctx: ActivityContext[MessageActivity]): # Use Agent with MCP tools (stateful conversation) result = await responses_agent.send(query) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) @@ -111,7 +111,7 @@ async def handle_prompt_chat(ctx: ActivityContext[MessageActivity]): ), ) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) @@ -157,7 +157,7 @@ async def handle_fallback_message(ctx: ActivityContext[MessageActivity]): # Use Agent with MCP tools for general conversation result = await responses_agent.send(ctx.activity.text) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) diff --git a/uv.lock b/uv.lock index 399b737f..ae2a82c6 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,7 @@ resolution-markers = [ members = [ "ai-test", "cards", + "defferred-ai", "dialogs", "echo", "graph", @@ -591,6 +592,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/67/ac57fbef5414ce84fe0bdeb497918ab2c781ff2cbf23c1bd91334b225669/cyclopts-3.23.1-py3-none-any.whl", hash = "sha256:8e57c6ea47d72b4b565c6a6c8a9fd56ed048ab4316627991230f4ad24ce2bc29", size = 85222, upload-time = "2025-08-30T17:40:33.005Z" }, ] +[[package]] +name = "defferred-ai" +version = "0.1.0" +source = { virtual = "tests/defferred_ai" } +dependencies = [ + { name = "dotenv" }, + { name = "microsoft-teams-apps" }, +] + +[package.metadata] +requires-dist = [ + { name = "dotenv", specifier = ">=0.9.9" }, + { name = "microsoft-teams-apps", editable = "packages/apps" }, +] + [[package]] name = "dependency-injector" version = "4.48.1"