-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathagent_and_run_level_middleware.py
More file actions
299 lines (240 loc) · 11.3 KB
/
agent_and_run_level_middleware.py
File metadata and controls
299 lines (240 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
FunctionInvocationContext,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from a .env file at import time, i.e. before main()
# constructs AzureAIAgentClient.
# NOTE(review): AzureAIAgentClient(credential=...) is given no endpoint/model explicitly,
# so it presumably reads its configuration from these environment variables — confirm.
load_dotenv()
"""
Agent-Level and Run-Level Middleware Example
This sample demonstrates the difference between agent-level and run-level middleware:
- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs)
- Run-level middleware: Applied to specific runs only (isolated per run)
The example shows:
1. Agent-level security middleware that validates all requests
2. Agent-level performance monitoring across all runs
3. Run-level context middleware for specific use cases (high priority, debugging)
4. Run-level caching middleware for expensive operations
Agent Middleware Execution Order:
When both agent-level and run-level *agent* middleware are configured, they execute
in this order:
1. Agent-level middleware (outermost) - executes first, in the order they were registered
2. Run-level middleware (innermost) - executes next, in the order they were passed to run()
3. Agent execution - the actual agent logic runs last
For example, with agent middleware [A1, A2] and run middleware [R1, R2]:
Request -> A1 -> A2 -> R1 -> R2 -> Agent -> R2 -> R1 -> A2 -> A1 -> Response
This means:
- Agent middleware wraps ALL run middleware and the agent
- Run middleware wraps only the agent for that specific run
- Each middleware can modify the context before AND after calling next()
Note: Function middleware executes during tool invocation, and chat middleware
executes around each model call inside the agent execution, not in the outer
agent-middleware chain shown above. They follow the same ordering principle:
agent-level function/chat middleware runs before run-level function/chat middleware.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # The docstring above is left unchanged on purpose: the @tool decorator may expose it
    # to the model as the tool description. The result below is random mock data.
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    # randint is inclusive on both ends, so the upper bound is len - 1. Deriving it from
    # the tuple (instead of a hard-coded 3) keeps every condition reachable and avoids an
    # IndexError if the tuple is edited.
    condition = conditions[randint(0, len(conditions) - 1)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
# Agent-level middleware (applied to ALL runs)
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent-level guard that screens the newest message of every request.

    Because it is registered on the agent itself, it wraps every ``agent.run`` call.
    Only the last entry of ``context.messages`` is inspected: if its text contains one
    of the blocked terms (case-insensitive substring match), the middleware returns
    without calling ``call_next`` so nothing further in the chain runs.
    """

    async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
        """Stop the run on a sensitive keyword; otherwise mark it validated and continue.

        Args:
            context: Middleware context; ``messages`` is read and ``metadata`` is written.
            call_next: Continuation that runs the remaining middleware and the agent.
        """
        print("[SecurityMiddleware] Checking security for all requests...")
        newest = context.messages[-1] if context.messages else None
        text = newest.text if newest else None
        if text:
            lowered = text.lower()
            blocked_terms = ("password", "secret", "credentials")
            if any(term in lowered for term in blocked_terms):
                print("[SecurityMiddleware] Security violation detected! Blocking request.")
                # Skipping call_next() short-circuits every inner middleware and the agent.
                # NOTE(review): context.result is not set here; confirm what agent.run
                # returns to the caller for a blocked request.
                return
        print("[SecurityMiddleware] Security check passed.")
        # Inner (run-level) middleware can read this flag.
        context.metadata["security_validated"] = True
        await call_next()
async def performance_monitor_middleware(
    context: AgentContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Agent-level middleware that times every run of the agent.

    Measures the elapsed time spent inside ``call_next()`` (all inner middleware plus
    the agent itself), prints it, and stores it in seconds under
    ``context.metadata["execution_time"]``.

    Args:
        context: Middleware context; only ``metadata`` is written.
        call_next: Continuation that runs the rest of the chain and the agent.

    Raises:
        Whatever ``call_next()`` raises is propagated unchanged, after the timing
        has been recorded.
    """
    print("[PerformanceMonitor] Starting performance monitoring...")
    # perf_counter is monotonic and high-resolution; time.time() follows the wall clock,
    # which can jump (NTP/manual adjustments) and yield wrong or negative durations.
    start_time = time.perf_counter()
    try:
        await call_next()
    finally:
        # Record the measurement even when an inner middleware or the agent raises.
        duration = time.perf_counter() - start_time
        print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s")
        context.metadata["execution_time"] = duration
# Run-level middleware (applied to specific runs only)
class HighPriorityMiddleware(AgentMiddleware):
    """Run-level middleware that flags a single run as high priority.

    It reports whether ``metadata["security_validated"]`` was already set by outer
    middleware. It then sets ``metadata["priority"] = "high"`` and
    ``metadata["expedited"] = True`` before handing control to the inner chain.
    """

    async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
        """Add priority flags to ``context.metadata``, run the inner chain, then log completion.

        Args:
            context: Middleware context; ``metadata`` is read and written.
            call_next: Continuation that runs the remaining middleware and the agent.
        """
        print("[HighPriority] Processing high priority request with expedited handling...")
        meta = context.metadata
        # Agent-level middleware wraps run-level middleware, so keys it wrote are visible here.
        if meta.get("security_validated"):
            print("[HighPriority] Security validation confirmed from agent middleware")
        meta.update({"priority": "high", "expedited": True})
        await call_next()
        print("[HighPriority] High priority processing completed")
async def debugging_middleware(
    context: AgentContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Run-level middleware that prints diagnostics for one specific run.

    Prints the number of messages, the ``stream`` flag, and any metadata already placed
    on the context by outer middleware. It then sets ``metadata["debug_enabled"] = True``
    and continues the chain.

    Args:
        context: Middleware context; ``messages`` and ``stream`` are read, ``metadata``
            is read and written.
        call_next: Continuation that runs the remaining middleware and the agent.
    """
    print("[Debug] Debug mode enabled for this run")
    message_count = len(context.messages)
    print(f"[Debug] Messages count: {message_count}")
    print(f"[Debug] Is streaming: {context.stream}")

    # Outer middleware may already have written keys; only dump the dict when non-empty.
    existing = context.metadata
    if existing:
        print(f"[Debug] Existing metadata: {existing}")

    context.metadata["debug_enabled"] = True
    await call_next()
    print("[Debug] Debug information collected")
class CachingMiddleware(AgentMiddleware):
    """Run-level middleware that memoizes agent responses by the latest message text.

    Each instance owns a plain in-memory dict, so reusing one instance across several
    ``agent.run`` calls lets a later run reuse an earlier response. On a hit, the stored
    response is placed on ``context.result`` and ``call_next()`` is skipped. Only truthy
    results are stored, and entries are never evicted or expired.
    """

    def __init__(self) -> None:
        # Last-message text (or the literal "no_message") -> response produced for it.
        self.cache: dict[str, AgentResponse] = {}

    async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
        """Serve ``context.result`` from the cache, or run the chain and cache its result.

        Args:
            context: Middleware context; ``messages`` is read, ``result`` is read and
                written, and ``metadata["cache_key"]`` is set on a miss.
            call_next: Continuation that runs the remaining middleware and the agent.
        """
        newest = context.messages[-1] if context.messages else None
        key: str = "no_message"
        if newest and newest.text:
            key = newest.text
        preview = key[:30]

        if key in self.cache:
            print(f"[Cache] Cache HIT for: '{preview}...'")
            # Short-circuit: hand back the stored response without running the agent.
            context.result = self.cache[key]  # type: ignore
            return

        print(f"[Cache] Cache MISS for: '{preview}...'")
        context.metadata["cache_key"] = key
        await call_next()

        produced = context.result
        if produced:
            self.cache[key] = produced  # type: ignore
            print("[Cache] Result cached for future use")
async def function_logging_middleware(
    context: FunctionInvocationContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Function middleware that logs a tool invocation before and after it runs.

    Args:
        context: Function invocation context; ``function.name`` and ``arguments`` are
            read for the log lines.
        call_next: Continuation that performs the actual function call.
    """
    name = context.function.name
    print(f"[FunctionLog] Calling function: {name} with args: {context.arguments}")
    await call_next()
    print(f"[FunctionLog] Function {name} completed")
def _print_banner(title: str) -> None:
    """Print a run title framed above and below by a 60-character '=' rule."""
    rule = "=" * 60
    print(rule)
    print(title)
    print(rule)


async def _run_and_report(
    agent,
    title: str,
    query: str,
    middleware: list[object] | None = None,
    fallback: str = "No response",
) -> None:
    """Print a banner and the user query, run the agent once, then print its reply.

    Args:
        agent: The agent created in ``main``.
        title: Banner text describing this run.
        query: User message passed to ``agent.run``.
        middleware: Optional run-level middleware. When ``None`` the ``middleware``
            keyword is not passed at all, exactly like a plain ``agent.run(query)``.
        fallback: Text printed when the run yields no result or an empty ``text``.
    """
    _print_banner(title)
    print(f"User: {query}")
    if middleware is None:
        result = await agent.run(query)
    else:
        result = await agent.run(query, middleware=middleware)
    # Guard against a missing result as well as empty text: middleware that returns
    # without calling call_next() (e.g. the security check) may leave no response.
    print(f"Agent: {result.text if result and result.text else fallback}")
    print()


async def main() -> None:
    """Run seven queries that contrast agent-level and run-level middleware.

    The agent is created once with three agent-level middleware (security, performance
    monitoring, function logging) that apply to every run. Individual runs then pass
    run-level middleware (high priority, debugging, caching) that apply only to that call.
    """
    print("=== Agent-Level and Run-Level Middleware Example ===\n")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            # Agent-level middleware: applied to ALL runs
            middleware=[
                SecurityAgentMiddleware(),
                performance_monitor_middleware,
                function_logging_middleware,
            ],
        ) as agent,
    ):
        print("Agent created with agent-level middleware:")
        print(" - SecurityMiddleware (blocks sensitive requests)")
        print(" - PerformanceMonitor (tracks execution time)")
        print(" - FunctionLogging (logs all function calls)")
        print()

        # Run 1: Normal query with no run-level middleware
        await _run_and_report(
            agent,
            "RUN 1: Normal query (agent-level middleware only)",
            "What's the weather like in Paris?",
        )

        # Run 2: High priority request with run-level middleware
        await _run_and_report(
            agent,
            "RUN 2: High priority request (agent + run-level middleware)",
            "What's the weather in Tokyo? This is urgent!",
            middleware=[HighPriorityMiddleware()],
        )

        # Run 3: Debug mode with run-level debugging middleware
        await _run_and_report(
            agent,
            "RUN 3: Debug mode (agent + run-level debugging)",
            "What's the weather in London?",
            middleware=[debugging_middleware],
        )

        # Runs 4 and 5 share ONE CachingMiddleware instance so Run 5 can hit the entry
        # that Run 4 stored for the same query.
        caching = CachingMiddleware()
        cached_query = "What's the weather in New York?"

        # Run 4: Multiple run-level middleware
        await _run_and_report(
            agent,
            "RUN 4: Multiple run-level middleware (caching + debug)",
            cached_query,
            middleware=[caching, debugging_middleware],
        )

        # Run 5: Test cache hit with same query
        await _run_and_report(
            agent,
            "RUN 5: Test cache hit (same query as Run 4)",
            cached_query,
            middleware=[caching],
        )

        # Run 6: Security violation test
        await _run_and_report(
            agent,
            "RUN 6: Security test (should be blocked by agent middleware)",
            "What's the secret weather password for Berlin?",
            fallback="Request was blocked by security middleware",
        )

        # Run 7: Normal query again (no run-level middleware interference)
        await _run_and_report(
            agent,
            "RUN 7: Normal query again (agent-level middleware only)",
            "What's the weather in Sydney?",
        )
if __name__ == "__main__":
    # Script entry point: run the async demo with asyncio.run().
    asyncio.run(main())