diff --git a/sample.py b/sample.py
index 4d5e327..5889163 100644
--- a/sample.py
+++ b/sample.py
@@ -59,6 +59,8 @@ class Args:
     param_dtype = jnp.float32
     dtype = jnp.bfloat16
     use_flash_attention: bool = True
+    # Data filtering parameters
+    darkness_threshold: float = 0.0
 
 
 args = tyro.cli(Args)
@@ -195,6 +197,7 @@ def _autoreg_sample(genie, rng, video_batch_BSHWC, action_batch_E):
         num_workers=0,
         prefetch_buffer_size=1,
         seed=args.seed,
+        darkness_threshold=args.darkness_threshold,
     )
     dataloader = iter(dataloader)
     video_batch_BSHWC = next(dataloader)
diff --git a/train_dynamics.py b/train_dynamics.py
index c9707f0..fc7257e 100644
--- a/train_dynamics.py
+++ b/train_dynamics.py
@@ -44,6 +44,7 @@ class Args:
     )
     warmup_steps: int = 5000
     lr_schedule: str = "wsd" # supported options: wsd, cos
+    darkness_threshold: float = 0.0
     # Tokenizer
     tokenizer_dim: int = 512
     tokenizer_ffn_dim: int = 2048
@@ -319,6 +320,7 @@ def loss_fn(model: Genie) -> tuple[jax.Array, tuple[jax.Array, dict]]:
         # The dataloader shards the dataset across all processes
         args.batch_size,
         *image_shape,
+        darkness_threshold=args.darkness_threshold,
         num_workers=8,
         prefetch_buffer_size=1,
         seed=args.seed,
diff --git a/train_lam.py b/train_lam.py
index 388f355..834850a 100644
--- a/train_lam.py
+++ b/train_lam.py
@@ -46,6 +46,7 @@ class Args:
     warmup_steps: int = 5000
     lr_schedule: str = "wsd" # supported options: wsd, cos
     vq_reset_thresh: int = 50
+    darkness_threshold: float = 0.0
     # LAM
     model_dim: int = 512
     ffn_dim: int = 2048
@@ -297,6 +298,7 @@ def loss_fn(
         # The dataloader shards the dataset across all processes
         args.batch_size,
         *image_shape,
+        darkness_threshold=args.darkness_threshold,
         num_workers=8,
         prefetch_buffer_size=1,
         seed=args.seed,
diff --git a/train_tokenizer.py b/train_tokenizer.py
index 5e71b2f..f3c36cb 100644
--- a/train_tokenizer.py
+++ b/train_tokenizer.py
@@ -45,6 +45,7 @@ class Args:
     )
     lr_schedule: str = "wsd" # supported options: wsd, cos
     warmup_steps: int = 10000
+    darkness_threshold: float = 0.0
     # Tokenizer
     model_dim: int = 512
     ffn_dim: int = 2048
@@ -289,6 +290,7 @@ def loss_fn(model: TokenizerVQVAE) -> tuple[jax.Array, tuple[jax.Array, dict]]:
         # The dataloader shards the dataset across all processes
         args.batch_size,
         *image_shape,
+        darkness_threshold=args.darkness_threshold,
         num_workers=8,
         prefetch_buffer_size=1,
         seed=args.seed,
diff --git a/utils/dataloader.py b/utils/dataloader.py
index 448e38a..b6dc4ef 100644
--- a/utils/dataloader.py
+++ b/utils/dataloader.py
@@ -4,6 +4,7 @@ from typing import Any
 
 import pickle
 
+RGB_TO_GRAYSCALE_WEIGHTS = np.array([0.2989, 0.5870, 0.1140])
 
 class EpisodeLengthFilter(grain.transforms.Filter):
     """
@@ -96,6 +97,38 @@ def random_map(self, element: dict, rng: np.random.Generator) -> Any:
         return seq
 
 
+class DarknessFilter(grain.transforms.Filter):
+    """
+    A Grain Filter that drops sequences whose average brightness falls below a threshold.
+    """
+
+    def __init__(self, darkness_threshold: float):
+        """Initializes the filter with the darkness threshold."""
+        self.darkness_threshold = darkness_threshold
+
+    def filter(self, element: Any) -> bool:
+        """
+        Filters sequences based on darkness.
+
+        Args:
+            element: A NumPy array representing a processed video sequence.
+
+        Returns:
+            True if the sequence is not too dark, False otherwise.
+        """
+        # Convert the RGB frames to grayscale using the standard luma weights
+        element_greyscale = np.dot(element[..., :3], RGB_TO_GRAYSCALE_WEIGHTS)
+        average_brightness = np.mean(element_greyscale)
+        if average_brightness < self.darkness_threshold:
+            print(
+                f"Filtering out sequence with average brightness {average_brightness:.4f}, "
+                f"which is below the darkness threshold {self.darkness_threshold}."
+            )
+            return False
+
+        return True
+
+
 def get_dataloader(
     array_record_paths: list[str],
     seq_len: int,
@@ -103,6 +136,7 @@ def get_dataloader(
     image_h: int,
     image_w: int,
     image_c: int,
+    darkness_threshold: float = 0.0,
     num_workers: int = 1,
     prefetch_buffer_size: int = 1,
     seed: int = 42,
@@ -139,6 +173,9 @@ def get_dataloader(
         ProcessEpisodeAndSlice(
             seq_len=seq_len, image_h=image_h, image_w=image_w, image_c=image_c
         ),
+        DarknessFilter(
+            darkness_threshold=darkness_threshold
+        ),
         grain.transforms.Batch(batch_size=per_process_batch_size, drop_remainder=True),
     ]