Skip to content

Commit a981b5a

Browse files
Cheuk Lun Koccreutzi
authored andcommitted
Minor updates to agent examples
1 parent adfe496 commit a981b5a

File tree

7 files changed

+60
-59
lines changed

7 files changed

+60
-59
lines changed

examples/FitPolynomialToDataUsingAIAgentExample.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -420,14 +420,14 @@ while ~problemSolved
420420
history = addResponseMessage(history,completeOutput);
421421
422422
% Action
423-
history = addUserMessage(history,"Call tools to solve the problem.");
423+
history = addUserMessage(history,"Execute the next step.");
424424
[~,completeOutput] = generate(llm,history,ToolChoice="required");
425425
history = addResponseMessage(history,completeOutput);
426426
actions = completeOutput.tool_calls;
427427
if isscalar(actions) && strcmp(actions(1).function.name,"finalAnswer")
428428
history = addToolMessage(history,actions.id,"finalAnswer","Final answer below");
429429
history = addUserMessage(history,"Return the final answer concisely.");
430-
agentResponse = generate(llm,history,ResponseFormat=responseFormat);
430+
agentResponse = generate(llm,history,ToolChoice="none",ResponseFormat=responseFormat);
431431
problemSolved = true;
432432
else
433433
for i = 1:numel(actions)
@@ -466,13 +466,18 @@ catch
466466
end
467467
468468
% Validate tool parameters
469-
toolSpec = toolRegistry(toolName);
470-
requiredArgs = string(fieldnames(toolSpec.toolSpecification.Parameters));
469+
tool = toolRegistry(toolName);
470+
requiredArgs = string(fieldnames(tool.toolSpecification.Parameters));
471471
assert(all(isfield(args,requiredArgs)),"Invalid tool parameters: %s",strjoin(fieldnames(args),","))
472472
473+
extraArgs = setdiff(string(fieldnames(args)),requiredArgs);
474+
if ~isempty(extraArgs)
475+
warning("Ignoring extra tool parameters: %s",strjoin(extraArgs,","));
476+
end
477+
473478
% Execute tool
474-
argValues = cellfun(@(fieldName) args.(fieldName),cellstr(requiredArgs),UniformOutput=false);
475-
functionHandle = toolSpec.functionHandle;
479+
argValues = arrayfun(@(fieldName) args.(fieldName),requiredArgs,UniformOutput=false);
480+
functionHandle = tool.functionHandle;
476481
nout = nargout(functionHandle);
477482
if nout == 2
478483
[data,observation] = functionHandle(data,argValues{:});
93 Bytes
Loading

examples/SolveSimpleMathProblemUsingAIAgent.md

Lines changed: 49 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ The `solveQuadraticEquation` function takes the three coefficients of a quadrati
4747

4848
```matlab
4949
function r = solveQuadraticEquation(a,b,c)
50-
r = roots([a b c]);
50+
r = roots([a b c]);
5151
end
5252
```
5353

@@ -98,38 +98,34 @@ toolRegistry("smallestRealNumber") = struct( ...
9898
"functionHandle",@smallestRealNumber);
9999
```
100100

101-
Define a function to evaluate tool calls identified by the LLM. LLMs can hallucinate tool calls or make errors about the parameters that the tools need. Therefore, first validate the tool name and parameters by comparing them to the `toolRegistry` dictionary. Then, run the functions associated with the tools using the [`feval`](https://www.mathworks.com/help/matlab/ref/feval.html) function.
101+
Define a function to evaluate tool calls identified by the LLM. LLMs can hallucinate tool calls or make errors about the parameters that the tools need. Therefore, first validate the tool name and parameters by comparing them to the `toolRegistry` dictionary. Then, run the functions associated with the tools. Return the result as an observation to the agent.
102102

103103
```matlab
104-
function result = evaluateToolCall(toolCall,toolRegistry)
105-
% Validate tool name
106-
toolName = toolCall.function.name;
107-
assert(isKey(toolRegistry,toolName),"Invalid tool name ''%s''.",toolName)
108-
109-
% Validate JSON syntax
110-
try
111-
args = jsondecode(toolCall.function.arguments);
112-
catch
113-
error("Model returned invalid JSON syntax for arguments of tool ''%s''.",toolName);
114-
end
115-
116-
% Validate tool parameters
117-
tool = toolRegistry(toolName);
118-
requiredArgs = string(fieldnames(tool.toolSpecification.Parameters));
119-
assert(all(isfield(args,requiredArgs)),"Invalid tool parameters: %s",strjoin(fieldnames(args),","))
120-
121-
extraArgs = setdiff(string(fieldnames(args)),requiredArgs);
122-
if ~isempty(extraArgs)
123-
warning("Ignoring extra tool parameters: %s",strjoin(extraArgs,","));
124-
end
125-
126-
% Execute tool
127-
argValues = arrayfun(@(fieldName) args.(fieldName),requiredArgs,UniformOutput=false);
128-
try
129-
result = feval(tool.functionHandle,argValues{:});
130-
catch ME
131-
error("Tool call '%s' failed with error: %s",toolName,ME.message)
132-
end
104+
function observation = evaluateToolCall(toolCall,toolRegistry)
105+
% Validate tool name
106+
toolName = toolCall.function.name;
107+
assert(isKey(toolRegistry,toolName),"Invalid tool name ''%s''.",toolName)
108+
109+
% Validate JSON syntax
110+
try
111+
args = jsondecode(toolCall.function.arguments);
112+
catch
113+
error("Model returned invalid JSON syntax for arguments of tool ''%s''.",toolName);
114+
end
115+
116+
% Validate tool parameters
117+
tool = toolRegistry(toolName);
118+
requiredArgs = string(fieldnames(tool.toolSpecification.Parameters));
119+
assert(all(isfield(args,requiredArgs)),"Invalid tool parameters: %s",strjoin(fieldnames(args),","))
120+
121+
extraArgs = setdiff(string(fieldnames(args)),requiredArgs);
122+
if ~isempty(extraArgs)
123+
warning("Ignoring extra tool parameters: %s",strjoin(extraArgs,","));
124+
end
125+
126+
% Execute tool
127+
argValues = arrayfun(@(fieldName) args.(fieldName),requiredArgs,UniformOutput=false);
128+
observation = tool.functionHandle(argValues{:});
133129
end
134130
```
135131
# Set Up ReAct Agent
@@ -149,10 +145,10 @@ This architecture is an iterative workflow. For each iteration, the agent perfor
149145
2. Action — The agent executes the next action.
150146
3. Observation — The agent observes the tool output.
151147

152-
Define the function `runAgent` that answers a user query `userQuery` using the ReAct agent architecture and the tools provided in `toolRegistry`.
148+
Define the function `runReActAgent` that answers a user query `userQuery` using the ReAct agent architecture and the tools provided in `toolRegistry`.
153149

154150
```matlab
155-
function agentResponse = runAgent(userQuery,toolRegistry)
151+
function agentResponse = runReActAgent(userQuery,toolRegistry)
156152
```
157153

158154
To ensure the agent stops after it answers the user query, create a tool `finalAnswer` and add it to the tool list.
@@ -170,10 +166,10 @@ systemPrompt = ...
170166
"Solve the problem. When done, call the tool finalAnswer else you will get stuck in a loop.";
171167
```
172168

173-
Connect to the OpenAI Chat Completion API using the [`openAIChat`](../doc/functions/openAIChat.md) function. Use the OpenAI model `"gpt-4.1"`. Provide the LLM with tools using the `Tools` name\-value argument. Initialize the message history.
169+
Connect to the OpenAI Chat Completion API using the [`openAIChat`](../doc/functions/openAIChat.md) function. Use the OpenAI model GPT\-4.1 mini. Provide the LLM with tools using the `Tools` name\-value argument. Initialize the message history.
174170

175171
```matlab
176-
chat = openAIChat(systemPrompt,ModelName="gpt-4.1",Tools=tools);
172+
llm = openAIChat(systemPrompt,ModelName="gpt-4.1-mini",Tools=tools);
177173
history = messageHistory;
178174
```
179175

@@ -200,28 +196,28 @@ while ~problemSolved
200196
Instruct the agent to plan the next step. Generate a response from the message history. To ensure the agent outputs text, set the `ToolChoice` name\-value argument to `"none"`.
201197

202198
```matlab
203-
history = addUserMessage(history,"Plan your next step.");
204-
[thought,completeOutput] = generate(chat,history,ToolChoice="none");
199+
history = addUserMessage(history,"Plan your single next step concisely.");
200+
[thought,completeOutput] = generate(llm,history,ToolChoice="none");
205201
disp("[Thought] " + thought);
206202
history = addResponseMessage(history,completeOutput);
207203
```
208204

209-
Instruct the agent to call a tool. Instruct the agent to always call a tool in this step.
205+
Instruct the agent to always call a tool.
210206

211207
```matlab
212-
history = addUserMessage(history,"Call tools to solve the problem. Always call a tool.");
213-
[~,completeOutput] = generate(chat,history);
208+
history = addUserMessage(history,"Execute the next step.");
209+
[~,completeOutput] = generate(llm,history,ToolChoice="required");
214210
history = addResponseMessage(history,completeOutput);
215211
actions = completeOutput.tool_calls;
216212
```
217213

218-
If the agent calls the `finalAnswer` tool, add the return the final agent response to the message history.
214+
If the agent calls the `finalAnswer` tool, add the final agent response to the message history.
219215

220216
```matlab
221217
if isscalar(actions) && strcmp(actions(1).function.name,"finalAnswer")
222218
history = addToolMessage(history,actions.id,"finalAnswer","Final answer below");
223219
history = addUserMessage(history,"Return the final answer as a statement.");
224-
agentResponse = generate(chat,history,ToolChoice="none");
220+
agentResponse = generate(llm,history,ToolChoice="none");
225221
problemSolved = true;
226222
```
227223

@@ -239,8 +235,8 @@ Otherwise, log and evaluate each tool call in the agent output.
239235
To enable the agent to observe the output, add the tool call result to the message history.
240236

241237
```matlab
242-
fprintf("[Observation] Result from tool '%s': %s\n",toolName,jsonencode(string(observation)));
243-
history = addToolMessage(history,action.id,toolName,"Observation: " + jsonencode(string(observation)));
238+
fprintf("[Observation] Result from tool '%s': %s\n",toolName,jsonencode(observation));
239+
history = addToolMessage(history,action.id,toolName,"Observation: " + jsonencode(observation));
244240
end
245241
end
246242
end
@@ -252,20 +248,20 @@ Define the query. Answer the query using the agent.
252248

253249
```matlab
254250
userQuery = "What is the smallest root of x^2+2x-3=0?";
255-
agentResponse = runAgent(userQuery,toolRegistry);
251+
agentResponse = runReActAgent(userQuery,toolRegistry);
256252
```
257253

258254
```matlabTextOutput
259255
User: What is the smallest root of x^2+2x-3=0?
260-
[Thought] To find the smallest root, I will:
261-
1. Solve the quadratic equation x^2 + 2x - 3 = 0 to find both roots.
262-
2. Compare the two roots and select the smallest one.
256+
[Thought] To find the smallest root of the quadratic equation \(x^2 + 2x - 3 = 0\), I will first solve the quadratic equation to find both roots. Then, I will determine the smallest root.
257+
258+
I will start by solving the quadratic equation.
263259
[Action] Calling tool 'solveQuadraticEquation' with args: "{\"a\":1,\"b\":2,\"c\":-3}"
264-
[Observation] Result from tool 'solveQuadraticEquation': ["-3","1"]
265-
[Thought] Now that I have both roots (-3 and 1), I will compare them to determine which is the smallest root. Then I will provide the smallest root as the final answer.
260+
[Observation] Result from tool 'solveQuadraticEquation': [-3.0000000000000004,0.99999999999999978]
261+
[Thought] I have found the roots of the equation to be approximately -3 and 1. Now, I will identify the smallest root between these two values.
266262
[Action] Calling tool 'smallestRealNumber' with args: "{\"x1\":\"-3\",\"x2\":\"1\"}"
267-
[Observation] Result from tool 'smallestRealNumber': "-3"
268-
[Thought] I have identified -3 as the smallest root. My next step is to provide -3 as the final answer.
263+
[Observation] Result from tool 'smallestRealNumber': -3
264+
[Thought] The smallest root of the quadratic equation \(x^2 + 2x - 3 = 0\) is -3. I will now provide this as the final answer.
269265
```
270266

271267

@@ -276,7 +272,7 @@ disp(agentResponse);
276272
```
277273

278274
```matlabTextOutput
279-
The smallest root of the equation x^2 + 2x - 3 = 0 is -3.
275+
The smallest root of the equation \(x^2 + 2x - 3 = 0\) is -3.
280276
```
281277

282278
# References
-53 Bytes
Binary file not shown.
-52 Bytes
Binary file not shown.
-19 KB
Binary file not shown.
-628 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)