diff --git a/README.md b/README.md index e9503a2..16b6a3c 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,23 @@ pip install -r requirements.txt Before training the models, generate the CoinRun dataset by running: ```bash -python generate_dataset.py --num_episodes 10000 +python generate_dataset.py --num_episodes 10000 --env_name coinrun ``` +See [here](https://github.com/openai/procgen?tab=readme-ov-file#environments) for more environments. + Note: this is a large dataset (around 100GB) and may take a while to generate. +To generate other datasets from the gym environment run: + +```bash +python generate_dataset_gym.py --num_episodes 10000 --env_name Acrobot-v1 +``` + +See [here](https://gym.openai.com/envs/#classic_control) for more environments. + +Note: This project uses gym==0.25.2 for backwards compatibility. Newer versions of gym are not supported. +

Quick Start 🚀

Genie has three components: a [video tokenizer](models/tokenizer.py), a [latent action model](models/lam.py), and a [dynamics model](models/dynamics.py). Each of these components are trained separately, however, the dynamics model requires a pre-trained video tokenizer and latent action model. diff --git a/generate_dataset.py b/generate_dataset.py index a67c424..5500bf7 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -5,33 +5,42 @@ from dataclasses import dataclass from pathlib import Path +import time +import logging from gym3 import types_np import numpy as np from procgen import ProcgenGym3Env import tyro +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) @dataclass class Args: num_episodes: int = 10000 - output_dir: str = "data/coinrun_episodes" + env_name: str = "coinrun" min_episode_length: int = 50 - + output_dir: str = "data" args = tyro.cli(Args) -output_dir = Path(args.output_dir) + +output_dir = Path(f"{args.output_dir}/{args.env_name}_episodes") output_dir.mkdir(parents=True, exist_ok=True) # --- Generate episodes --- i = 0 metadata = [] +times= [] while i < args.num_episodes: seed = np.random.randint(0, 10000) env = ProcgenGym3Env(num=1, env_name="coinrun", start_level=seed) dataseq = [] # --- Run episode --- + logger.info(f"Generating episode {i}...") + start_time = time.time() for j in range(1000): env.act(types_np.sample(env.ac_space, bshape=(env.num,))) rew, obs, first = env.observe() @@ -45,11 +54,19 @@ class Args: episode_path = output_dir / f"episode_{i}.npy" np.save(episode_path, episode_data.astype(np.uint8)) metadata.append({"path": str(episode_path), "length": len(dataseq)}) - print(f"Episode {i} completed, length: {len(dataseq)}") + # log time per episode + times.append(time.time() - start_time) + logger.info(f"Episode {i} completed, length: {len(dataseq)}, time: {time.time() - start_time}") i += 1 + + # Save metadata every 1000 episodes + if i % 1000 == 0: + np.save(output_dir / f"metadata_episodes_{i}.npy", metadata) else: - print(f"Episode too short ({len(dataseq)}), resampling...") + logger.warning(f"Episode too short ({len(dataseq)}), resampling...") + # --- Save metadata --- np.save(output_dir / "metadata.npy", metadata) -print(f"Dataset generated with {len(metadata)} valid episodes") +logger.info(f"Dataset generated with {len(metadata)} valid episodes") +logger.info(f"Average time per episode: {np.mean(times)}") diff --git a/generate_dataset_gym_multi.py b/generate_dataset_gym_multi.py new file mode 100644 index 0000000..5134245 --- /dev/null +++ b/generate_dataset_gym_multi.py @@ -0,0 +1,70 @@ +""" +Generates a dataset from the gym environment. +Episodes are saved individually as memory-mapped files for efficient loading. +""" + +from dataclasses import dataclass +from pathlib import Path + +import gymnasium as gym +import numpy as np +import tyro +import time +import multiprocessing as mp + +@dataclass +class Args: + num_episodes: int = 10000 + env_name: str = "Acrobot-v1" + min_episode_length: int = 50 + seed: int = 42 + + +def generate_episode(args_tuple): + env_name, min_episode_length, seed, episode_idx, output_dir = args_tuple + env = gym.make(env_name, render_mode="rgb_array") + observation, info = env.reset(seed=seed + episode_idx) + dataseq = [] + print(f"Episode {episode_idx} started") + for j in range(1000): + action = env.action_space.sample() + observation, reward, terminated, truncated, info = env.step(action) + dataseq.append(env.render()) + if terminated or truncated: + break + if len(dataseq) >= min_episode_length: + episode_data = np.stack(dataseq, axis=0) + episode_path = output_dir / f"episode_{episode_idx}.npy" + np.save(episode_path, episode_data.astype(np.uint8)) + print(f"Episode {episode_idx} saved") + + return {"path": str(episode_path), "length": len(dataseq)} + else: + return None + +def main(): + args = tyro.cli(Args) + output_dir = Path(f"data/{args.env_name}_episodes_{args.num_episodes}") + output_dir.mkdir(parents=True, exist_ok=True) + + pool_args = [ + (args.env_name, args.min_episode_length, args.seed, i, output_dir) + for i in range(args.num_episodes) + ] + + print(f"Number of processes: {mp.cpu_count()}") + + with mp.Pool(processes=mp.cpu_count()) as pool: + results = pool.map(generate_episode, pool_args) + + # Filter out None (episodes that were too short) + metadata = [r for r in results if r is not None] + np.save(output_dir / "metadata.npy", metadata) + print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}") + + +if __name__ == '__main__': + start_time = time.time() + main() + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") diff --git a/generate_dataset_gym_single.py b/generate_dataset_gym_single.py new file mode 100644 index 0000000..b473e5b --- /dev/null +++ b/generate_dataset_gym_single.py @@ -0,0 +1,73 @@ +""" +Generates a dataset from the gym environment. +Episodes are saved individually as memory-mapped files for efficient loading. +""" + +from dataclasses import dataclass +from pathlib import Path + +import gymnasium as gym +import numpy as np +import tyro +import time + +import crafter + + +@dataclass +class Args: + num_episodes: int = 10000 + env_name: str = "Acrobot-v1" + min_episode_length: int = 50 + seed: int = 42 + + +def main(): + args = tyro.cli(Args) + output_dir = Path(f"data/{args.env_name}_episodes_{args.num_episodes}") + output_dir.mkdir(parents=True, exist_ok=True) + + # --- Generate episodes --- + i = 0 + metadata = [] + time_per_episode = [] + while i < args.num_episodes: + time_start_episode = time.time() + env = gym.make(args.env_name, render_mode="rgb_array") + observation, info = env.reset(seed=args.seed) + dataseq = [] + + # --- Run episode --- + for j in range(1000): + action = env.action_space.sample() + observation, reward, terminated, truncated, info = env.step(action) + dataseq.append(env.render()) + if terminated or truncated: + break + + # --- Save episode --- + if len(dataseq) >= args.min_episode_length: + episode_data = np.stack(dataseq, axis=0) + episode_path = output_dir / f"episode_{i}.npy" + np.save(episode_path, episode_data.astype(np.uint8)) + time_per_episode.append(time.time() - time_start_episode) + metadata.append({"path": str(episode_path), "length": len(dataseq)}) + if i % 5 == 0: + print(f"Episode {i} completed, length: {len(dataseq)}, time: {time_per_episode[-1]} seconds") + else: + print(f"Episode {i} completed, length: {len(dataseq)}") + i += 1 + else: + print(f"Episode too short ({len(dataseq)}), resampling...") + + # --- Save metadata --- + np.save(output_dir / "metadata.npy", metadata) + print(f"Dataset generated with {len(metadata)} valid episodes, saving to {output_dir}") + print(f"Average time per episode: {np.mean(time_per_episode)} seconds") + + +if __name__ == '__main__': + start_time = time.time() + main() + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") diff --git a/requirements.txt b/requirements.txt index 8699240..db6f11b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,8 @@ optax>=0.2.3 procgen>=0.10.7 torch>=2.0.1 tyro>=0.8.5 -wandb>=0.17.4 \ No newline at end of file +wandb>=0.17.4 +gym==0.25.2 # Last version that supports backward compatibility https://github.com/openai/gym/releases/tag/0.26.0 +# To support Box2D environments +swig==4.3.1 +gymnasium[box2d] \ No newline at end of file