A clean, elegant implementation of the MuZero algorithm in JAX.
For guidelines on contributing to this project, please see the CONTRIBUTING document.
- JAX/Flax-powered: Built on JAX and Flax for high performance and research flexibility
- Minimal dependencies: JAX, Flax, Optax, Chex, NumPy, and Gymnasium
- Clean architecture: Modular design with separate components (such as the agent and MCTS)
- Type hints: Full type annotations, checked with mypy
- Simple configuration: A single config class holds all parameters
- Fast development: Uses uv for project management
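To illustrate what "a single config class for all parameters" can look like, here is a minimal sketch using a frozen Python dataclass. The class name `SketchConfig`, the default values, and the validation rules are hypothetical; only the field names come from the examples in this README, and CuMind's actual `Config` may be structured differently.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical single config object; CuMind's real Config may differ."""

    action_space_size: int = 2                 # number of discrete actions
    observation_shape: Tuple[int, ...] = (4,)  # (features,) or (channels, height, width)
    hidden_dim: int = 64                       # assumed: size of the network's hidden layers/state
    num_simulations: int = 25                  # MCTS simulations per move
    seed: int = 0                              # assumed default; seeds the JAX PRNG

    def __post_init__(self) -> None:
        # Fail fast on invalid settings (these rules are assumptions, not CuMind's).
        if self.action_space_size < 1:
            raise ValueError("action_space_size must be at least 1")
        if len(self.observation_shape) not in (1, 3):
            raise ValueError("observation_shape must be 1D or 3D")
        if self.num_simulations < 1:
            raise ValueError("num_simulations must be at least 1")


# CartPole-style defaults, then an Atari-style variant derived from them.
cartpole = SketchConfig()
atari = replace(cartpole, action_space_size=4, observation_shape=(3, 84, 84), num_simulations=50)
print(atari.observation_shape, atari.num_simulations)
```

Freezing the dataclass means variants are derived with `dataclasses.replace`, which builds a new object and reruns the validation, rather than mutating a shared config in place.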
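Install uv:

```bash
pip install -U uv
```

To get started quickly, use CPU mode to avoid large downloads:

```bash
# Clone the repository
git clone git@github.com:carletonai/cumind.git
cd cumind

# Install dependencies
uv sync
```

A minimal example for CartPole:

```python
from cumind import Agent, Config
import gymnasium as gym
import jax

# Create configuration for CartPole (1D observations)
config = Config(
    action_space_size=2,
    observation_shape=(4,),  # 1D vector
    hidden_dim=64,
    num_simulations=25,
)

# Initialize JAX random key
key = jax.random.PRNGKey(config.seed)

# Create agent
agent = Agent(config)

# Get an initial observation from the environment
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=config.seed)

# Select action
action = agent.select_action(observation)
```

For image-based environments (like Atari), `Config` takes a 3D `observation_shape` in (channels, height, width) order.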
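Gymnasium image environments typically return frames in (height, width, channels) order, while CuMind's image configuration is written as `observation_shape=(channels, height, width)`. A minimal NumPy sketch of the conversion plus a shape check, assuming frames have already been resized to 84x84; the helpers `to_channels_first` and `check_observation` are hypothetical and not part of CuMind's API:

```python
import numpy as np


def to_channels_first(frame: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder an (H, W, C) image to (C, H, W)."""
    if frame.ndim != 3:
        raise ValueError(f"expected a 3D image, got shape {frame.shape}")
    return np.transpose(frame, (2, 0, 1))


def check_observation(observation: np.ndarray, observation_shape: tuple) -> None:
    """Hypothetical helper: raise if an observation does not match the configured shape."""
    if tuple(observation.shape) != tuple(observation_shape):
        raise ValueError(
            f"observation shape {observation.shape} != configured {tuple(observation_shape)}"
        )


# Assumed input: an RGB frame already resized to 84x84, in (H, W, C) order.
frame = np.zeros((84, 84, 3), dtype=np.uint8)
obs = to_channels_first(frame)
check_observation(obs, (3, 84, 84))
print(obs.shape)
```

Checking the shape before calling the agent turns a silent layout mismatch into an explicit error; 1D observations such as CartPole's need no transpose and can be checked against `(4,)` directly.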
```python
from cumind import Config

# Create configuration for Atari (3D observations)
config = Config(
    action_space_size=4,
    observation_shape=(3, 84, 84),  # 3D image: channels, height, width
    hidden_dim=64,
    num_simulations=50,
)
```

The project uses uv instead of traditional package managers. Here are the commands for common development tasks:
```bash
# Format code
uv run ruff format

# Lint code
uv run ruff check

# Fix linting issues in place
uv run ruff check --fix

# Type checking
uv run mypy src/
```

Run the test suite with pytest:

```bash
# Run all tests
uv run pytest

# Run tests with verbose output
uv run pytest -v

# Run a specific test file
uv run pytest tests/test_mcts.py
```

To remove build artifacts, lock files, and cached dependencies:
```bash
# Remove uv lock file
rm uv.lock

# Clean uv's cache (removes downloaded wheels, etc.)
uv cache clean

# Remove build artifacts and temporary files
uv clean

# Remove tool caches (.pytest_cache, .ruff_cache, .mypy_cache, ...)
rm -rf .*_cache
```

This project is licensed under the MIT License. See the LICENSE file for details.