Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,23 @@ pip install -r requirements.txt
Before training the models, generate the CoinRun dataset by running:

```bash
python generate_dataset.py --num_episodes 10000
python generate_dataset.py --num_episodes 10000 --env_name coinrun
```

See [here](https://github.com/openai/procgen?tab=readme-ov-file#environments) for more environments.

Note: this is a large dataset (around 100GB) and may take a while to generate.

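To generate datasets from other Gym environments, run `generate_dataset_gym_multi.py` (generates episodes in parallel, one worker process per CPU core) or `generate_dataset_gym_single.py` (generates episodes sequentially), for example:

```bash
python generate_dataset_gym_multi.py --num_episodes 10000 --env_name Acrobot-v1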
```

See [here](https://gym.openai.com/envs/#classic_control) for more environments.

Note: This project uses gym==0.25.2 for backwards compatibility. Newer versions of gym are not supported.

<h2 name="train" id="train">Quick Start 🚀 </h2>

Genie has three components: a [video tokenizer](models/tokenizer.py), a [latent action model](models/lam.py), and a [dynamics model](models/dynamics.py). Each of these components is trained separately; however, the dynamics model requires a pre-trained video tokenizer and latent action model.
Expand Down
29 changes: 23 additions & 6 deletions generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,42 @@

from dataclasses import dataclass
from pathlib import Path
import time
import logging

from gym3 import types_np
import numpy as np
from procgen import ProcgenGym3Env
import tyro

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Args:
    """Command-line options for the dataset generator, parsed with ``tyro.cli``."""

    # Number of episodes to save before the generation loop stops.
    num_episodes: int = 10000
    # Base output directory; episodes are written to
    # "<output_dir>/<env_name>_episodes". Declared exactly once: the stale
    # "data/coinrun_episodes" default was shadowed by a second declaration of
    # the same field, so "data" was already the effective default.
    output_dir: str = "data"
    # Environment name; used to build the name of the output sub-directory.
    env_name: str = "coinrun"
    # Minimum episode length; presumably shorter episodes are discarded and
    # resampled (the check itself is outside this view) -- TODO confirm.
    min_episode_length: int = 50

args = tyro.cli(Args)
output_dir = Path(args.output_dir)

output_dir = Path(f"{args.output_dir}/{args.env_name}_episodes")
output_dir.mkdir(parents=True, exist_ok=True)

# --- Generate episodes ---
i = 0
metadata = []
times= []
while i < args.num_episodes:
seed = np.random.randint(0, 10000)
env = ProcgenGym3Env(num=1, env_name="coinrun", start_level=seed)
dataseq = []

# --- Run episode ---
logger.info(f"Generating episode {i}...")
start_time = time.time()
for j in range(1000):
env.act(types_np.sample(env.ac_space, bshape=(env.num,)))
rew, obs, first = env.observe()
Expand All @@ -45,11 +54,19 @@ class Args:
episode_path = output_dir / f"episode_{i}.npy"
np.save(episode_path, episode_data.astype(np.uint8))
metadata.append({"path": str(episode_path), "length": len(dataseq)})
print(f"Episode {i} completed, length: {len(dataseq)}")
# log time per episode
times.append(time.time() - start_time)
logger.info(f"Episode {i} completed, length: {len(dataseq)}, time: {time.time() - start_time}")
i += 1

# Save metadata every 1000 episodes
if i % 1000 == 0:
np.save(output_dir / f"metadata_episodes_{i}.npy", metadata)
else:
print(f"Episode too short ({len(dataseq)}), resampling...")
logger.warning(f"Episode too short ({len(dataseq)}), resampling...")


# --- Save metadata ---
np.save(output_dir / "metadata.npy", metadata)
print(f"Dataset generated with {len(metadata)} valid episodes")
logger.info(f"Dataset generated with {len(metadata)} valid episodes")
logger.info(f"Average time per episode: {np.mean(times)}")
70 changes: 70 additions & 0 deletions generate_dataset_gym_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
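Generates a dataset of rendered frames from a Gymnasium environment using random actions,
running episodes in parallel across worker processes.
Episodes are saved individually as .npy files, which can be memory-mapped when loaded
(np.load with mmap_mode) for efficient loading.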
"""

from dataclasses import dataclass
from pathlib import Path

import gymnasium as gym
import numpy as np
import tyro
import time
import multiprocessing as mp

@dataclass
class Args:
    """Command-line options for the parallel dataset generator (parsed by ``tyro.cli``)."""

    # Number of episode indices to roll out (one pool task per index).
    num_episodes: int = 10000
    # Environment id handed to ``gym.make``.
    env_name: str = "Acrobot-v1"
    # Episodes with fewer rendered frames than this are not saved.
    min_episode_length: int = 50
    # Base seed; the episode index is added to it for each rollout.
    seed: int = 42


def generate_episode(args_tuple):
    """Roll out one random-action episode and save its rendered frames.

    Takes a single tuple so it can be used directly as a
    ``multiprocessing.Pool.map`` worker.

    Args:
        args_tuple: ``(env_name, min_episode_length, seed, episode_idx,
            output_dir)``. The episode is seeded with ``seed + episode_idx``
            and written to ``output_dir / "episode_<episode_idx>.npy"``.

    Returns:
        ``{"path": str, "length": int}`` for a saved episode, or ``None`` when
        the episode produced fewer than ``min_episode_length`` frames (nothing
        is written in that case).
    """
    env_name, min_episode_length, seed, episode_idx, output_dir = args_tuple
    episode_seed = seed + episode_idx
    max_steps = 1000  # hard cap on frames per episode

    dataseq = []
    env = gym.make(env_name, render_mode="rgb_array")
    try:
        env.reset(seed=episode_seed)
        # The action space has its own RNG that env.reset() does not seed;
        # seed it too so that a given (seed, episode_idx) is reproducible.
        env.action_space.seed(episode_seed)
        print(f"Episode {episode_idx} started")
        for _ in range(max_steps):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            dataseq.append(env.render())
            if terminated or truncated:
                break
    finally:
        # Release the environment even if a step or render call fails;
        # otherwise every pool task leaks one environment.
        env.close()

    if len(dataseq) < min_episode_length:
        return None

    episode_data = np.stack(dataseq, axis=0)
    episode_path = Path(output_dir) / f"episode_{episode_idx}.npy"
    np.save(episode_path, episode_data.astype(np.uint8))
    print(f"Episode {episode_idx} saved")
    return {"path": str(episode_path), "length": len(dataseq)}

def main():
    """Generate episodes in parallel and write the combined metadata file.

    Parses ``Args`` from the command line, creates
    ``data/<env_name>_episodes_<num_episodes>``, maps ``generate_episode``
    over one task per episode index using a process pool sized to the CPU
    count, and saves the metadata of the episodes that were kept to
    ``metadata.npy`` in that directory.
    """
    cli = tyro.cli(Args)
    output_dir = Path(f"data/{cli.env_name}_episodes_{cli.num_episodes}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # One argument tuple per episode index, in index order.
    tasks = []
    for episode_idx in range(cli.num_episodes):
        tasks.append(
            (cli.env_name, cli.min_episode_length, cli.seed, episode_idx, output_dir)
        )

    print(f"Number of processes: {mp.cpu_count()}")

    with mp.Pool(processes=mp.cpu_count()) as workers:
        outcomes = workers.map(generate_episode, tasks)

    # Workers return None for episodes that were too short; keep the rest.
    metadata = []
    for outcome in outcomes:
        if outcome is not None:
            metadata.append(outcome)

    np.save(output_dir / "metadata.npy", metadata)
    print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}")


# Script entry point: run the generator and report the total wall-clock time.
if __name__ == "__main__":
    start_time = time.time()
    main()
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
73 changes: 73 additions & 0 deletions generate_dataset_gym_single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
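Generates a dataset of rendered frames from a Gymnasium environment using random actions,
running episodes sequentially in a single process.
Episodes are saved individually as .npy files, which can be memory-mapped when loaded
(np.load with mmap_mode) for efficient loading.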
"""

from dataclasses import dataclass
from pathlib import Path

import gymnasium as gym
import numpy as np
import tyro
import time

import crafter


@dataclass
class Args:
    """Command-line options for the sequential dataset generator (parsed by ``tyro.cli``)."""

    # Number of episodes to save before the generation loop stops.
    num_episodes: int = 10000
    # Environment id handed to ``gym.make``.
    env_name: str = "Acrobot-v1"
    # Episodes with fewer rendered frames than this are discarded and retried.
    min_episode_length: int = 50
    # Seed used when resetting the environment.
    seed: int = 42


def main():
    """Generate ``num_episodes`` random-action episodes one after another.

    Parses ``Args`` from the command line and writes each kept episode to
    ``data/<env_name>_episodes_<num_episodes>/episode_<i>.npy`` as a uint8
    array of rendered frames, followed by ``metadata.npy`` listing the path
    and length of every saved episode. Episodes with fewer than
    ``min_episode_length`` frames are discarded and another rollout is tried.

    Rollout ``k`` (counting discarded ones) is reset with seed
    ``args.seed + k``. Resetting every episode with the same seed would give
    every episode the same initial state.
    """
    args = tyro.cli(Args)
    output_dir = Path(f"data/{args.env_name}_episodes_{args.num_episodes}")
    output_dir.mkdir(parents=True, exist_ok=True)

    max_steps = 1000  # hard cap on frames per episode

    # Build the environment once and reuse it; it is closed in the finally
    # block below instead of leaking one instance per rollout.
    env = gym.make(args.env_name, render_mode="rgb_array")
    # The action space has its own RNG that env.reset() does not seed; seed it
    # once so a run is reproducible for a given --seed.
    env.action_space.seed(args.seed)

    # --- Generate episodes ---
    i = 0  # number of episodes saved so far
    attempt = 0  # number of rollouts started, including discarded ones
    metadata = []
    time_per_episode = []
    try:
        while i < args.num_episodes:
            time_start_episode = time.time()
            env.reset(seed=args.seed + attempt)
            attempt += 1
            dataseq = []

            # --- Run episode ---
            for _ in range(max_steps):
                action = env.action_space.sample()
                _, _, terminated, truncated, _ = env.step(action)
                dataseq.append(env.render())
                if terminated or truncated:
                    break

            # --- Save episode ---
            if len(dataseq) < args.min_episode_length:
                print(f"Episode too short ({len(dataseq)}), resampling...")
                continue

            episode_data = np.stack(dataseq, axis=0)
            episode_path = output_dir / f"episode_{i}.npy"
            np.save(episode_path, episode_data.astype(np.uint8))
            time_per_episode.append(time.time() - time_start_episode)
            metadata.append({"path": str(episode_path), "length": len(dataseq)})
            if i % 5 == 0:
                print(f"Episode {i} completed, length: {len(dataseq)}, time: {time_per_episode[-1]} seconds")
            else:
                print(f"Episode {i} completed, length: {len(dataseq)}")
            i += 1
    finally:
        env.close()

    # --- Save metadata ---
    np.save(output_dir / "metadata.npy", metadata)
    print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}")
    # np.mean of an empty list yields nan plus a RuntimeWarning; skip the
    # summary when no episode was generated (e.g. --num_episodes 0).
    if time_per_episode:
        print(f"Average time per episode: {np.mean(time_per_episode)} seconds")


# Script entry point: run the generator and report the total wall-clock time.
if __name__ == "__main__":
    start_time = time.time()
    main()
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ optax>=0.2.3
procgen>=0.10.7
torch>=2.0.1
tyro>=0.8.5
wandb>=0.17.4
wandb>=0.17.4
gym==0.25.2 # Last version that supports backward compatibility https://github.com/openai/gym/releases/tag/0.26.0
# To support Box2D environments
swig==4.3.1
gymnasium[box2d]