from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING

from mcp_agent.agents.agent import Agent
from mcp_agent.tracing.semconv import GEN_AI_AGENT_NAME
from mcp_agent.tracing.telemetry import (
    get_tracer,
    record_attributes,
    serialize_attributes,
)
from mcp_agent.workflows.llm.augmented_llm import (
    AugmentedLLM,
    MessageParamT,
    MessageT,
    ModelT,
    RequestParams,
)
from mcp_agent.workflows.parallel.fan_in import FanInInput, FanIn
from mcp_agent.workflows.parallel.fan_out import FanOut

if TYPE_CHECKING:
    from mcp_agent.core.context import Context


class ParallelLLM(AugmentedLLM[MessageParamT, MessageT]):
    """
    LLMs can sometimes work simultaneously on a task (fan-out)
    and have their outputs aggregated programmatically (fan-in).
    This workflow performs both the fan-out and fan-in operations using LLMs.
    From the user's perspective, an input is specified and the output is returned.

    When to use this workflow:
        Parallelization is effective when the divided subtasks can be parallelized
        for speed (sectioning), or when multiple perspectives or attempts are needed
        for higher-confidence results (voting).

    Examples:
        Sectioning:
            - Implementing guardrails where one model instance processes user queries
              while another screens them for inappropriate content or requests.
            - Automating evals for evaluating LLM performance, where each LLM call
              evaluates a different aspect of the model's performance on a given prompt.
        Voting:
            - Reviewing a piece of code for vulnerabilities, where several different
              agents review and flag the code if they find a problem.
            - Evaluating whether a given piece of content is inappropriate,
              with multiple agents evaluating different aspects or requiring different
              vote thresholds to balance false positives and negatives.
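
    Usage (a minimal sketch; the agent names are illustrative, and
    OpenAIAugmentedLLM is assumed to be importable from
    mcp_agent.workflows.llm.augmented_llm_openai):

        proofreader = Agent(name="proofreader", instruction="Review grammar.")
        fact_checker = Agent(name="fact_checker", instruction="Verify the facts.")
        grader = Agent(name="grader", instruction="Combine feedback into a report.")

        parallel = ParallelLLM(
            fan_in_agent=grader,
            fan_out_agents=[proofreader, fact_checker],
            llm_factory=OpenAIAugmentedLLM,
        )
        report = await parallel.generate_str("Grade this short story: ...")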
"""
def __init__(
self,
fan_in_agent: Agent | AugmentedLLM | Callable[[FanInInput], Any],
fan_out_agents: List[Agent | AugmentedLLM] | None = None,
fan_out_functions: List[Callable] | None = None,
name: str | None = None,
llm_factory: Callable[[Agent], AugmentedLLM] = None,
context: Optional["Context"] = None,
**kwargs,
):
"""
Initialize the LLM with a list of server names and an instruction.
If a name is provided, it will be used to identify the LLM.
If an agent is provided, all other properties are optional
"""
super().__init__(
name=name,
instruction="You are a parallel LLM workflow that can fan-out to multiple LLMs and fan-in to an aggregator LLM.",
context=context,
**kwargs,
)
self.llm_factory = llm_factory
self.fan_in_agent = fan_in_agent
self.fan_out_agents = fan_out_agents
self.fan_out_functions = fan_out_functions
self.history = (
None # History tracking is complex in this workflow, so it is not supported
)
self.fan_in_fn: Callable[[FanInInput], Any] = None
self.fan_in: FanIn = None
if isinstance(fan_in_agent, Callable):
self.fan_in_fn = fan_in_agent
else:
self.fan_in = FanIn(
aggregator_agent=fan_in_agent,
llm_factory=llm_factory,
context=context,
)
self.fan_out = FanOut(
agents=fan_out_agents,
functions=fan_out_functions,
llm_factory=llm_factory,
context=context,
)

    async def generate(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> List[MessageT] | Any:
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate"
        ) as span:
            if self.context.tracing_enabled:
                span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
                self._annotate_span_for_generation_message(span, message)
                if request_params:
                    AugmentedLLM.annotate_span_with_request_params(span, request_params)

            # First, we fan-out: responses maps each fan-out agent/function
            # name to its list of results.
            responses = await self.fan_out.generate(
                message=message,
                request_params=request_params,
            )

            if self.context.tracing_enabled:
                for agent_name, fan_out_responses in responses.items():
                    res_attributes = {}
                    for i, res in enumerate(fan_out_responses):
                        try:
                            res_dict = (
                                res if isinstance(res, dict) else res.model_dump()
                            )
                            res_attributes.update(
                                serialize_attributes(res_dict, f"response.{i}")
                            )
                        # pylint: disable=broad-exception-caught
                        except Exception:
                            # Just no-op, best-effort tracing
                            continue
                    span.add_event(f"fan_out.{agent_name}.responses", res_attributes)

            # Then, we fan-in
            if self.fan_in_fn:
                # A callable fan-in is awaited directly, so it must return an awaitable.
                result = await self.fan_in_fn(responses)
            else:
                result = await self.fan_in.generate(
                    messages=responses,
                    request_params=request_params,
                )

            if self.context.tracing_enabled:
                try:
                    if isinstance(result, list):
                        for i, res in enumerate(result):
                            res_dict = (
                                res if isinstance(res, dict) else res.model_dump()
                            )
                            record_attributes(span, res_dict, f"response.{i}")
                    else:
                        res_dict = (
                            result if isinstance(result, dict) else result.model_dump()
                        )
                        record_attributes(span, res_dict, "response")
                # pylint: disable=broad-exception-caught
                except Exception:
                    # Just no-op, best-effort tracing
                    pass

            return result

    async def generate_str(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> str:
        """Request an LLM generation and return the string representation of the result"""
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate_str"
        ) as span:
            if self.context.tracing_enabled:
                span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
                self._annotate_span_for_generation_message(span, message)
                if request_params:
                    AugmentedLLM.annotate_span_with_request_params(span, request_params)

            # First, we fan-out
            responses = await self.fan_out.generate(
                message=message,
                request_params=request_params,
            )

            if self.context.tracing_enabled:
                for agent_name, fan_out_responses in responses.items():
                    res_attributes = {}
                    for i, res in enumerate(fan_out_responses):
                        try:
                            res_dict = (
                                res if isinstance(res, dict) else res.model_dump()
                            )
                            res_attributes.update(
                                serialize_attributes(res_dict, f"response.{i}")
                            )
                        # pylint: disable=broad-exception-caught
                        except Exception:
                            # Just no-op, best-effort tracing
                            continue
                    span.add_event(f"fan_out.{agent_name}.responses", res_attributes)

            # Then, we fan-in
            if self.fan_in_fn:
                result = str(await self.fan_in_fn(responses))
            else:
                result = await self.fan_in.generate_str(
                    messages=responses,
                    request_params=request_params,
                )

            span.set_attribute("response", result)
            return result

    async def generate_structured(
        self,
        message: str | MessageParamT | List[MessageParamT],
        response_model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> ModelT:
        """Request a structured LLM generation and return the result as a Pydantic model."""
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate_structured"
        ) as span:
            if self.context.tracing_enabled:
                self._annotate_span_for_generation_message(span, message)
                span.set_attribute(
                    "response_model",
                    f"{response_model.__module__}.{response_model.__name__}",
                )
                if request_params:
                    AugmentedLLM.annotate_span_with_request_params(span, request_params)

            # First, we fan-out
            responses = await self.fan_out.generate(
                message=message,
                request_params=request_params,
            )

            if self.context.tracing_enabled:
                for agent_name, fan_out_responses in responses.items():
                    res_attributes = {}
                    for i, res in enumerate(fan_out_responses):
                        try:
                            res_dict = (
                                res if isinstance(res, dict) else res.model_dump()
                            )
                            res_attributes.update(
                                serialize_attributes(res_dict, f"response.{i}")
                            )
                        # pylint: disable=broad-exception-caught
                        except Exception:
                            # Just no-op, best-effort tracing
                            continue
                    span.add_event(f"fan_out.{agent_name}.responses", res_attributes)

            # Then, we fan-in
            if self.fan_in_fn:
                # A callable fan-in is expected to produce the response_model
                # instance itself.
                result = await self.fan_in_fn(responses)
            else:
                result = await self.fan_in.generate_structured(
                    messages=responses,
                    response_model=response_model,
                    request_params=request_params,
                )

            if self.context.tracing_enabled:
                try:
                    span.set_attribute(
                        "structured_response_json", result.model_dump_json()
                    )
                # pylint: disable=broad-exception-caught
                except Exception:
                    pass  # Just no-op, best-effort tracing

            return result
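

# A custom fan-in need not be an Agent or AugmentedLLM: any callable taking the
# fan-out results (FanInInput, the mapping of agent/function name to its list of
# responses seen above) may aggregate them. Since generate()/generate_str()
# await the callable, it should be a coroutine function. A minimal sketch (the
# reviewer agents and OpenAIAugmentedLLM are illustrative assumptions, not part
# of this module):
#
#     async def majority_vote(responses: FanInInput) -> str:
#         # Take the last message from each branch and return the most common one.
#         votes = [str(branch[-1]) for branch in responses.values()]
#         return max(set(votes), key=votes.count)
#
#     parallel = ParallelLLM(
#         fan_in_agent=majority_vote,
#         fan_out_agents=[reviewer_a, reviewer_b, reviewer_c],
#         llm_factory=OpenAIAugmentedLLM,
#     )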