
CuMind is a JAX-based reinforcement learning framework inspired by Google DeepMind's MuZero, which combines Monte Carlo Tree Search (MCTS) with a learned model to achieve superhuman performance in complex domains without prior knowledge of their rules.


carletonai/CuMind


CuMind

A clean, elegant implementation of the MuZero algorithm in JAX.
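MuZero's learned model is usually described as three functions: a representation function h (observation to latent state), a dynamics function g (latent state and action to next latent state and reward), and a prediction function f (latent state to policy and value). MCTS plans entirely inside this latent space. The sketch below is a minimal plain-JAX illustration of that structure with one dense layer per head; it is not CuMind's Flax network, and every name in it (`init_model`, `representation`, `dynamics`, `prediction`) is hypothetical. The `tanh` on latent states is an arbitrary choice to keep them bounded.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_dim, out_dim):
    """He-style random weights and zero bias for one dense layer (hypothetical helper)."""
    w = jax.random.normal(key, (in_dim, out_dim)) * jnp.sqrt(2.0 / in_dim)
    return {"w": w, "b": jnp.zeros(out_dim)}


def linear(params, x):
    return x @ params["w"] + params["b"]


def init_model(key, obs_dim, action_dim, hidden_dim):
    k = jax.random.split(key, 5)
    return {
        "repr": init_linear(k[0], obs_dim, hidden_dim),
        "dyn_state": init_linear(k[1], hidden_dim + action_dim, hidden_dim),
        "dyn_reward": init_linear(k[2], hidden_dim + action_dim, 1),
        "policy": init_linear(k[3], hidden_dim, action_dim),
        "value": init_linear(k[4], hidden_dim, 1),
    }


def representation(params, obs):
    """h: raw observation -> latent state."""
    return jnp.tanh(linear(params["repr"], obs))


def dynamics(params, state, action, action_dim):
    """g: (latent state, action) -> (next latent state, predicted reward)."""
    x = jnp.concatenate([state, jax.nn.one_hot(action, action_dim)])
    next_state = jnp.tanh(linear(params["dyn_state"], x))
    reward = linear(params["dyn_reward"], x)[0]
    return next_state, reward


def prediction(params, state):
    """f: latent state -> (policy logits, value estimate)."""
    logits = linear(params["policy"], state)
    value = linear(params["value"], state)[0]
    return logits, value


# Unroll one imagined step from a CartPole-sized observation.
params = init_model(jax.random.PRNGKey(0), obs_dim=4, action_dim=2, hidden_dim=64)
s0 = representation(params, jnp.ones(4))
s1, r1 = dynamics(params, s0, action=1, action_dim=2)
logits, v1 = prediction(params, s1)
```

Because each function is pure, the same code can be wrapped in `jax.jit` or batched with `jax.vmap` without modification.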

Contributing

For guidelines on contributing to this project, please see the CONTRIBUTING document.

Features

  • JAX/Flax powered: Built on JAX for high performance and research flexibility
  • Minimal dependencies: JAX, Flax, Optax, Chex, NumPy, and Gymnasium
  • Clean architecture: Modular design with separate components (e.g. the agent, configuration, and MCTS)
  • Type hints: Full type annotations, checked with mypy
  • Simple configuration: Single config class for all parameters
  • Fast development: Uses uv for project management

Installation

Get Started

pip install -U uv

Development Setup

To get started quickly, use CPU mode to avoid large downloads:

# Clone the repository
git clone git@github.com:carletonai/cumind.git
cd cumind

# Install dependencies
uv sync 

Quick Start

import gymnasium as gym
import jax

from cumind import Agent, Config

# Create configuration for CartPole (1D observations)
config = Config(
    action_space_size=2,
    observation_shape=(4,),  # 1D vector
    hidden_dim=64,
    num_simulations=25,
)

# Initialize JAX random key
key = jax.random.PRNGKey(config.seed)

# Create agent
agent = Agent(config)

# Get an initial observation from the environment
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=config.seed)

# Select action
action = agent.select_action(observation)
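The `num_simulations` field presumably sets how many MCTS simulations are run before each action is chosen. In MuZero, each simulation descends the search tree by picking the child with the highest pUCT score, and the final action is drawn from the root's visit counts. The sketch below shows that rule with the exploration constants from the MuZero pseudocode (1.25 and 19652); it assumes Q-values are already normalized to [0, 1] and approximates the parent's visit count as the sum of its children's counts. Whether CuMind uses exactly this rule is an assumption, and the names `puct_scores`, `select_child`, and `visit_policy` are hypothetical.

```python
import jax.numpy as jnp

# Exploration constants published with the MuZero pseudocode.
PB_C_INIT = 1.25
PB_C_BASE = 19652.0


def puct_scores(q_values, priors, visit_counts):
    """pUCT score for every child of one node (arrays of shape [num_actions])."""
    parent_visits = jnp.sum(visit_counts)  # approximation of the node's own count
    c = PB_C_INIT + jnp.log((parent_visits + PB_C_BASE + 1.0) / PB_C_BASE)
    exploration = c * priors * jnp.sqrt(parent_visits) / (1.0 + visit_counts)
    return q_values + exploration


def select_child(q_values, priors, visit_counts):
    """Index of the child the current simulation descends into."""
    return int(jnp.argmax(puct_scores(q_values, priors, visit_counts)))


def visit_policy(visit_counts, temperature=1.0):
    """Action distribution from root visit counts once all simulations finish."""
    scaled = visit_counts ** (1.0 / temperature)
    return scaled / jnp.sum(scaled)
```

With equal Q-values and equal visit counts, the exploration term favors the child with the larger prior; as a child's visit count grows, its bonus shrinks, so more simulations shift the choice toward children with high Q-values.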

For image-based environments (like Atari):

# Create configuration for Atari (3D observations)
config = Config(
    action_space_size=4,
    observation_shape=(3, 84, 84),  # 3D image: channels, height, width
    hidden_dim=64,
    num_simulations=50
)
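Gymnasium's Atari environments typically return RGB frames channels-last, e.g. `(210, 160, 3)` of `uint8`, whereas the configuration above declares a channels-first `(3, 84, 84)` observation. A minimal preprocessing sketch that bridges the two with `jax.image.resize` follows; the function name `preprocess_frame` and the [0, 1] scaling are assumptions, not part of CuMind's documented API.

```python
import jax
import jax.numpy as jnp


def preprocess_frame(frame, size=84):
    """(H, W, C) uint8 frame -> (C, size, size) float32 array scaled to [0, 1]."""
    x = jnp.asarray(frame, dtype=jnp.float32) / 255.0
    # Resize the spatial dims only; keep the channel axis unchanged.
    x = jax.image.resize(x, (size, size, x.shape[-1]), method="bilinear")
    # Move channels first to match observation_shape=(3, 84, 84).
    return jnp.transpose(x, (2, 0, 1))


frame = jnp.full((210, 160, 3), 255, dtype=jnp.uint8)  # ALE-sized RGB frame
obs = preprocess_frame(frame)
```

Frame stacking or grayscale conversion, if needed, would change the channel count, so `observation_shape` must be updated to match.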

Development Commands

The project uses uv instead of traditional package managers such as pip. Use the following commands to run the development tools through uv:

Code Quality

# Format code 
uv run ruff format

# Lint code 
uv run ruff check

# Automatically fix lint issues where possible
uv run ruff check --fix

# Type checking 
uv run mypy src/

Testing

# Run all tests 
uv run pytest

# Run tests with verbose output 
uv run pytest -v

# Run specific test file 
uv run pytest tests/test_mcts.py

Cleaning the Environment

To remove build artifacts, lock files, and cached dependencies:

# Remove uv lock file
rm uv.lock

# Clean uv's cache (removes downloaded wheels, etc.)
uv cache clean

# Remove build artifacts (e.g. output of uv build)
rm -rf build/ dist/

# Remove tool caches (.pytest_cache, .ruff_cache, .mypy_cache)
rm -rf .*_cache

License

This project is licensed under the MIT License. See the LICENSE file for details.
