Skip to content

Commit 55be38c

Browse files
authored
Test episodic memory (#26)
* test init for episodic memo * final test episodic memory * style fix to test_episodic_memory.py
1 parent f91e22e commit 55be38c

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import json
2+
from collections import deque
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from mesa_llm.memory.episodic_memory import EpisodicMemory
8+
from mesa_llm.memory.memory import MemoryEntry
9+
10+
11+
@pytest.fixture
12+
def mock_agent():
13+
agent = MagicMock(name="MockLLMAgent")
14+
15+
# Create a MagicMock for the LLM's response
16+
mock_response = MagicMock()
17+
18+
# This line *defines* the full nested path on the mock
19+
mock_response.choices[0].message.content = json.dumps({"grade": 3})
20+
21+
# Set this as the return value
22+
agent.llm.generate.return_value = mock_response
23+
24+
agent.model.steps = 100
25+
return agent
26+
27+
28+
class TestEpisodicMemory:
29+
"""Core functionality test"""
30+
31+
def test_memory_init(self, mock_agent):
32+
"""Test EpisodicMemory class initialization with defaults and custom values"""
33+
memory = EpisodicMemory(
34+
agent=mock_agent,
35+
max_capacity=10,
36+
considered_entries=5,
37+
llm_model="provider/test_model",
38+
)
39+
40+
assert memory.agent == mock_agent
41+
assert memory.max_capacity == 10
42+
assert memory.considered_entries == 5
43+
assert isinstance(memory.memory_entries, deque)
44+
assert memory.memory_entries.maxlen == 10
45+
assert memory.system_prompt is not None
46+
"""FYI: The above line may not always work; use the one below if needed."""
47+
# assert isinstance(memory.system_prompt,str), memory.system_prompt.strip() != ""
48+
49+
def test_add_memory_entry(self, mock_agent):
50+
"""Test adding memories to Episodic memory"""
51+
memory = EpisodicMemory(agent=mock_agent, llm_model="provider/test_model")
52+
53+
# Test basic addition with observation
54+
memory.add_to_memory("observation", {"step": 1, "content": "Test content"})
55+
56+
# Test with planning
57+
memory.add_to_memory("planning", {"plan": "Test plan", "importance": "high"})
58+
59+
# Test with action
60+
memory.add_to_memory("action", {"action": "Test action"})
61+
62+
# Should be empty step_content initially
63+
assert memory.step_content != {}
64+
65+
def test_grade_event_importance(self, mock_agent):
66+
"""Test grading event importance"""
67+
memory = EpisodicMemory(agent=mock_agent, llm_model="provider/test_model")
68+
69+
# 1. Set up a specific grade for this test
70+
mock_response = MagicMock()
71+
mock_response.choices[0].message.content = json.dumps({"grade": 5})
72+
mock_agent.llm.generate.return_value = mock_response
73+
74+
# 2. Call the method
75+
grade = memory.grade_event_importance("observation", {"data": "critical info"})
76+
77+
# 3. Assert the grade is correct
78+
assert grade == 5
79+
80+
# 4. Assert the LLM was called correctly
81+
mock_agent.llm.generate.assert_called_once()
82+
83+
# Check that the system prompt was set on the llm object
84+
assert memory.llm.system_prompt == memory.system_prompt
85+
86+
def test_retrieve_top_k_entries(self, mock_agent):
87+
"""Test the sorting logic for retrieving entries (importance - recency_penalty)."""
88+
memory = EpisodicMemory(agent=mock_agent, llm_model="provider/test_model")
89+
# Set current step
90+
mock_agent.model.steps = 100
91+
92+
# Manually add entries to bypass grading and control scores
93+
# score = importance - (current_step - entry_step)
94+
95+
# score = 5 - (100 - 98) = 3
96+
entry_a = MemoryEntry(
97+
content={"importance": 5, "id": "A"}, step=98, agent=mock_agent
98+
)
99+
# score = 1 - (100 - 99) = 0
100+
entry_b = MemoryEntry(
101+
content={"importance": 1, "id": "B"}, step=99, agent=mock_agent
102+
)
103+
# score = 4 - (100 - 90) = -6
104+
entry_c = MemoryEntry(
105+
content={"importance": 4, "id": "C"}, step=90, agent=mock_agent
106+
)
107+
# score = 4 - (100 - 95) = -1
108+
entry_d = MemoryEntry(
109+
content={"importance": 4, "id": "D"}, step=95, agent=mock_agent
110+
)
111+
112+
memory.memory_entries.extend([entry_a, entry_b, entry_c, entry_d])
113+
114+
# Retrieve top 3 (k=3)
115+
top_entries = memory.retrieve_top_k_entries(3)
116+
117+
# Expected order: A (3), B (0), D (-1)
118+
assert len(top_entries) == 3
119+
assert top_entries[0].content["id"] == "A"
120+
assert top_entries[1].content["id"] == "B"
121+
assert top_entries[2].content["id"] == "D"
122+
123+
# Entry C (score -6) should be omitted
124+
assert "C" not in [e.content["id"] for e in top_entries]

0 commit comments

Comments
 (0)