diff --git a/README.md b/README.md
index e9503a2..16b6a3c 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,23 @@ pip install -r requirements.txt
Before training the models, generate the CoinRun dataset by running:
```bash
-python generate_dataset.py --num_episodes 10000
+python generate_dataset.py --num_episodes 10000 --env_name coinrun
```
+See [here](https://github.com/openai/procgen?tab=readme-ov-file#environments) for more environments.
+
Note: this is a large dataset (around 100GB) and may take a while to generate.
+To generate other datasets from the gym environment run:
+
+```bash
+python generate_dataset_gym.py --num_episodes 10000 --env_name Acrobot-v1
+```
+
+See [here](https://gym.openai.com/envs/#classic_control) for more environments.
+
+Note: This project uses gym==0.25.2 for backwards compatibility. Newer versions of gym are not supported.
+
Quick Start 🚀
Genie has three components: a [video tokenizer](models/tokenizer.py), a [latent action model](models/lam.py), and a [dynamics model](models/dynamics.py). Each of these components are trained separately, however, the dynamics model requires a pre-trained video tokenizer and latent action model.
diff --git a/generate_dataset.py b/generate_dataset.py
index a67c424..5500bf7 100644
--- a/generate_dataset.py
+++ b/generate_dataset.py
@@ -5,33 +5,42 @@
from dataclasses import dataclass
from pathlib import Path
+import time
+import logging
from gym3 import types_np
import numpy as np
from procgen import ProcgenGym3Env
import tyro
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
@dataclass
class Args:
num_episodes: int = 10000
- output_dir: str = "data/coinrun_episodes"
+ env_name: str = "coinrun"
min_episode_length: int = 50
-
+ output_dir: str = "data"
args = tyro.cli(Args)
-output_dir = Path(args.output_dir)
+
+output_dir = Path(f"{args.output_dir}/{args.env_name}_episodes")
output_dir.mkdir(parents=True, exist_ok=True)
# --- Generate episodes ---
i = 0
metadata = []
+times= []
while i < args.num_episodes:
seed = np.random.randint(0, 10000)
env = ProcgenGym3Env(num=1, env_name="coinrun", start_level=seed)
dataseq = []
# --- Run episode ---
+ logger.info(f"Generating episode {i}...")
+ start_time = time.time()
for j in range(1000):
env.act(types_np.sample(env.ac_space, bshape=(env.num,)))
rew, obs, first = env.observe()
@@ -45,11 +54,19 @@ class Args:
episode_path = output_dir / f"episode_{i}.npy"
np.save(episode_path, episode_data.astype(np.uint8))
metadata.append({"path": str(episode_path), "length": len(dataseq)})
- print(f"Episode {i} completed, length: {len(dataseq)}")
+ # log time per episode
+ times.append(time.time() - start_time)
+ logger.info(f"Episode {i} completed, length: {len(dataseq)}, time: {time.time() - start_time}")
i += 1
+
+ # Save metadata every 1000 episodes
+ if i % 1000 == 0:
+ np.save(output_dir / f"metadata_episodes_{i}.npy", metadata)
else:
- print(f"Episode too short ({len(dataseq)}), resampling...")
+ logger.warning(f"Episode too short ({len(dataseq)}), resampling...")
+
# --- Save metadata ---
np.save(output_dir / "metadata.npy", metadata)
-print(f"Dataset generated with {len(metadata)} valid episodes")
+logger.info(f"Dataset generated with {len(metadata)} valid episodes")
+logger.info(f"Average time per episode: {np.mean(times)}")
diff --git a/generate_dataset_gym_multi.py b/generate_dataset_gym_multi.py
new file mode 100644
index 0000000..5134245
--- /dev/null
+++ b/generate_dataset_gym_multi.py
@@ -0,0 +1,70 @@
+"""
+Generates a dataset from the gym environment.
+Episodes are saved individually as memory-mapped files for efficient loading.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+
+import gymnasium as gym
+import numpy as np
+import tyro
+import time
+import multiprocessing as mp
+
+@dataclass
+class Args:
+ num_episodes: int = 10000
+ env_name: str = "Acrobot-v1"
+ min_episode_length: int = 50
+ seed: int = 42
+
+
+def generate_episode(args_tuple):
+ env_name, min_episode_length, seed, episode_idx, output_dir = args_tuple
+ env = gym.make(env_name, render_mode="rgb_array")
+ observation, info = env.reset(seed=seed + episode_idx)
+ dataseq = []
+ print(f"Episode {episode_idx} started")
+ for j in range(1000):
+ action = env.action_space.sample()
+ observation, reward, terminated, truncated, info = env.step(action)
+ dataseq.append(env.render())
+ if terminated or truncated:
+ break
+ if len(dataseq) >= min_episode_length:
+ episode_data = np.stack(dataseq, axis=0)
+ episode_path = output_dir / f"episode_{episode_idx}.npy"
+ np.save(episode_path, episode_data.astype(np.uint8))
+ print(f"Episode {episode_idx} saved")
+
+ return {"path": str(episode_path), "length": len(dataseq)}
+ else:
+ return None
+
+def main():
+ args = tyro.cli(Args)
+ output_dir = Path(f"data/{args.env_name}_episodes_{args.num_episodes}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ pool_args = [
+ (args.env_name, args.min_episode_length, args.seed, i, output_dir)
+ for i in range(args.num_episodes)
+ ]
+
+ print(f"Number of processes: {mp.cpu_count()}")
+
+ with mp.Pool(processes=mp.cpu_count()) as pool:
+ results = pool.map(generate_episode, pool_args)
+
+ # Filter out None (episodes that were too short)
+ metadata = [r for r in results if r is not None]
+ np.save(output_dir / "metadata.npy", metadata)
+ print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}")
+
+
+if __name__ == '__main__':
+ start_time = time.time()
+ main()
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time} seconds")
diff --git a/generate_dataset_gym_single.py b/generate_dataset_gym_single.py
new file mode 100644
index 0000000..b473e5b
--- /dev/null
+++ b/generate_dataset_gym_single.py
@@ -0,0 +1,73 @@
+"""
+Generates a dataset from the gym environment.
+Episodes are saved individually as memory-mapped files for efficient loading.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+
+import gymnasium as gym
+import numpy as np
+import tyro
+import time
+
+import crafter
+
+
+@dataclass
+class Args:
+ num_episodes: int = 10000
+ env_name: str = "Acrobot-v1"
+ min_episode_length: int = 50
+ seed: int = 42
+
+
+def main():
+ args = tyro.cli(Args)
+ output_dir = Path(f"data/{args.env_name}_episodes_{args.num_episodes}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # --- Generate episodes ---
+ i = 0
+ metadata = []
+ time_per_episode = []
+ while i < args.num_episodes:
+ time_start_episode = time.time()
+ env = gym.make(args.env_name, render_mode="rgb_array")
+ observation, info = env.reset(seed=args.seed)
+ dataseq = []
+
+ # --- Run episode ---
+ for j in range(1000):
+ action = env.action_space.sample()
+ observation, reward, terminated, truncated, info = env.step(action)
+ dataseq.append(env.render())
+ if terminated or truncated:
+ break
+
+ # --- Save episode ---
+ if len(dataseq) >= args.min_episode_length:
+ episode_data = np.stack(dataseq, axis=0)
+ episode_path = output_dir / f"episode_{i}.npy"
+ np.save(episode_path, episode_data.astype(np.uint8))
+ time_per_episode.append(time.time() - time_start_episode)
+ metadata.append({"path": str(episode_path), "length": len(dataseq)})
+ if i % 5 == 0:
+ print(f"Episode {i} completed, length: {len(dataseq)}, time: {time_per_episode[-1]} seconds")
+ else:
+ print(f"Episode {i} completed, length: {len(dataseq)}")
+ i += 1
+ else:
+ print(f"Episode too short ({len(dataseq)}), resampling...")
+
+ # --- Save metadata ---
+ np.save(output_dir / "metadata.npy", metadata)
+ print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}")
+ print(f"Average time per episode: {np.mean(time_per_episode)} seconds")
+
+
+if __name__ == '__main__':
+ start_time = time.time()
+ main()
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time} seconds")
diff --git a/requirements.txt b/requirements.txt
index 8699240..db6f11b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,8 @@ optax>=0.2.3
procgen>=0.10.7
torch>=2.0.1
tyro>=0.8.5
-wandb>=0.17.4
\ No newline at end of file
+wandb>=0.17.4
+gym==0.25.2 # Last version that supports backward compatibility https://github.com/openai/gym/releases/tag/0.26.0
+# To support Box2D environments
+swig==4.3.1
+gymnasium[box2d]
\ No newline at end of file