Skip to content

Commit ce2f0d5

Browse files
committed
implement VideoViz to record model runs in a video
1 parent 527c023 commit ce2f0d5

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,6 @@ dmypy.json
9292
# JS dependencies
9393
mesa/visualization/templates/external/
9494
mesa/visualization/templates/js/external/
95+
96+
# Video
97+
**/*.mp4
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Example of using VideoViz with the Schelling model."""
2+
3+
from mesa.examples.basic.schelling.model import Schelling
4+
from mesa.visualization.video_viz import (
5+
VideoViz,
6+
make_measure_component,
7+
make_space_component,
8+
)
9+
10+
# Create model
11+
model = Schelling(10, 10)
12+
13+
14+
def agent_portrayal(agent):
15+
"""Portray agents based on their type."""
16+
if agent is None:
17+
return {}
18+
19+
portrayal = {
20+
"color": "red" if agent.type == 0 else "blue",
21+
"size": 25,
22+
"marker": "s", # square marker
23+
}
24+
return portrayal
25+
26+
27+
# Create visualization with space and some metrics
28+
viz = VideoViz(
29+
model,
30+
[
31+
make_space_component(agent_portrayal=agent_portrayal, save_format="svg"),
32+
make_measure_component("happy", save_format="svg"),
33+
],
34+
title="Schelling's Segregation Model",
35+
)
36+
37+
# Record simulation
38+
if __name__ == "__main__":
39+
video_path = viz.record(steps=50, filepath="schelling.mp4")
40+
print(f"Video saved to: {video_path}")

mesa/visualization/video_viz.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""Video recording components for Mesa model visualization."""
2+
3+
import shutil
4+
from collections.abc import Callable, Sequence
5+
from pathlib import Path
6+
7+
import matplotlib.animation as animation
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
11+
import mesa
12+
from mesa.visualization.matplotlib_renderer import (
13+
MatplotlibRenderer,
14+
MeasureRendererMatplotlib,
15+
SpaceRenderMatplotlib,
16+
)
17+
18+
19+
def make_space_component(
20+
agent_portrayal: Callable | None = None,
21+
propertylayer_portrayal: dict | None = None,
22+
post_process: Callable | None = None,
23+
**space_drawing_kwargs,
24+
):
25+
"""Create a Matplotlib-based space visualization component.
26+
27+
Args:
28+
agent_portrayal: Function to portray agents.
29+
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
30+
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
31+
backend: The backend to use for rendering the space. Can be "matplotlib" or "altair".
32+
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
33+
the functions for drawing the various spaces for further details.
34+
35+
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
36+
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
37+
38+
39+
Returns:
40+
SpaceRenderMatplotlib: A component for rendering the space.
41+
"""
42+
if agent_portrayal is None:
43+
44+
def agent_portrayal(a):
45+
return {}
46+
47+
return SpaceRenderMatplotlib(
48+
agent_portrayal,
49+
propertylayer_portrayal,
50+
post_process=post_process,
51+
**space_drawing_kwargs,
52+
)
53+
54+
55+
def make_measure_component(
56+
measure: Callable,
57+
**kwargs,
58+
):
59+
"""Create a plotting function for a specified measure.
60+
61+
Args:
62+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
63+
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor.
64+
65+
Returns:
66+
MeasureRendererMatplotlib: A component for rendering the measure.
67+
"""
68+
return MeasureRendererMatplotlib(
69+
measure,
70+
**kwargs,
71+
)
72+
73+
74+
class VideoViz:
75+
"""Create high-quality video recordings of model simulations."""
76+
77+
def __init__(
78+
self,
79+
model: mesa.Model,
80+
components: Sequence[MatplotlibRenderer],
81+
*,
82+
title: str | None = None,
83+
figsize: tuple[float, float] | None = None,
84+
grid: tuple[int, int] | None = None,
85+
):
86+
"""Initialize video visualization configuration.
87+
88+
Args:
89+
model: The model to simulate and record
90+
components: Sequence of component objects defining what to visualize
91+
title: Optional title for the video
92+
figsize: Optional figure size in inches (width, height)
93+
grid: Optional (rows, cols) for custom layout. Auto-calculated if None.
94+
"""
95+
# Check if FFmpeg is available
96+
if not shutil.which("ffmpeg"):
97+
raise RuntimeError(
98+
"FFmpeg not found. Please install FFmpeg to save animations:\n"
99+
" - macOS: brew install ffmpeg\n"
100+
" - Linux: sudo apt-get install ffmpeg\n"
101+
" - Windows: download from https://ffmpeg.org/download.html"
102+
)
103+
self.model = model
104+
self.components = components
105+
self.title = title
106+
self.figsize = figsize
107+
self.grid = grid or self._calculate_grid(len(components))
108+
109+
# Setup figure and axes
110+
self.fig, self.axes = self._setup_figure()
111+
112+
def record(
113+
self,
114+
*,
115+
steps: int,
116+
filepath: str | Path,
117+
dpi: int = 100,
118+
fps: int = 10,
119+
codec: str = "h264",
120+
bitrate: int = 2000,
121+
) -> Path:
122+
"""Record model simulation to video file.
123+
124+
Args:
125+
steps: Number of simulation steps to record
126+
filepath: Where to save the video file
127+
dpi: Resolution of the output video
128+
fps: Frames per second in the output video
129+
codec: Video codec to use
130+
bitrate: Video bitrate in kbps (default: 2000)
131+
132+
Returns:
133+
Path to the saved video file
134+
135+
Raises:
136+
RuntimeError: If FFmpeg is not installed
137+
"""
138+
filepath = Path(filepath)
139+
140+
def update(frame_num):
141+
# Update model state
142+
self.model.step()
143+
144+
# Render all visualization frames
145+
for component, ax in zip(self.components, self.axes):
146+
ax.clear()
147+
component.draw(self.model, ax)
148+
return self.axes
149+
150+
# Create and save animation
151+
anim = animation.FuncAnimation(
152+
self.fig, update, frames=steps, interval=1000 / fps, blit=False
153+
)
154+
155+
writer = animation.FFMpegWriter(
156+
fps=fps,
157+
codec=codec,
158+
bitrate=bitrate, # Now passing as integer
159+
)
160+
161+
anim.save(filepath, writer=writer, dpi=dpi)
162+
return filepath
163+
164+
def _calculate_grid(self, n_frames: int) -> tuple[int, int]:
165+
"""Calculate optimal grid layout for given number of frames."""
166+
cols = min(3, n_frames) # Max 3 columns
167+
rows = int(np.ceil(n_frames / cols))
168+
return (rows, cols)
169+
170+
def _setup_figure(self):
171+
"""Setup matplotlib figure and axes."""
172+
if not self.figsize:
173+
self.figsize = (5 * self.grid[1], 5 * self.grid[0])
174+
fig = plt.figure(figsize=self.figsize)
175+
axes = []
176+
177+
for i in range(len(self.components)):
178+
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1)
179+
axes.append(ax)
180+
181+
if self.title:
182+
fig.suptitle(self.title, fontsize=16)
183+
fig.tight_layout()
184+
return fig, axes

0 commit comments

Comments
 (0)