From 97c8156e205da57a8c5b4fa662f5e5b4ac589322 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 12 Aug 2025 12:40:38 +0200 Subject: [PATCH 1/3] feat: add darkness filtering capability to dataloader - Introduced a DarknessFilter class to filter out sequences with images that are too dark based on a specified threshold. - Added a darkness_threshold parameter to Args and the get_dataloader function to support this feature. --- train_dynamics.py | 2 ++ utils/dataloader.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/train_dynamics.py b/train_dynamics.py index c9707f0..fc7257e 100644 --- a/train_dynamics.py +++ b/train_dynamics.py @@ -44,6 +44,7 @@ class Args: ) warmup_steps: int = 5000 lr_schedule: str = "wsd" # supported options: wsd, cos + darkness_threshold: float = 0.0 # Tokenizer tokenizer_dim: int = 512 tokenizer_ffn_dim: int = 2048 @@ -319,6 +320,7 @@ def loss_fn(model: Genie) -> tuple[jax.Array, tuple[jax.Array, dict]]: # The dataloader shards the dataset across all processes args.batch_size, *image_shape, + darkness_threshold=args.darkness_threshold, num_workers=8, prefetch_buffer_size=1, seed=args.seed, diff --git a/utils/dataloader.py b/utils/dataloader.py index 448e38a..39f86f1 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -96,6 +96,38 @@ def random_map(self, element: dict, rng: np.random.Generator) -> Any: return seq +class DarknessFilter(grain.transforms.Filter): + """ + A Grain Filter that filters out sequences with images that are too dark. + """ + + def __init__(self, darkness_threshold: float): + """Initializes the filter with darkness threshold.""" + self.darkness_threshold = darkness_threshold + + def filter(self, element: Any) -> bool: + """ + Filters sequences based on darkness. + + Args: + element: A NumPy array representing a processed video sequence. + + Returns: + True if the sequence is not too dark, False otherwise. + """ + # Convert the RGB image to grayscale using numpy + element_greyscale = np.dot(element[...,:3], [0.2989, 0.5870, 0.1140]) + average_brightness = np.mean(element_greyscale) + if average_brightness < self.darkness_threshold: + print( + f"Filtering out sequence with average brightness {average_brightness}, " + f"which is below the darkness threshold {self.darkness_threshold}." + ) + return False + + return True + + def get_dataloader( array_record_paths: list[str], seq_len: int, @@ -103,6 +135,7 @@ def get_dataloader( image_h: int, image_w: int, image_c: int, + darkness_threshold: float = 0., num_workers: int = 1, prefetch_buffer_size: int = 1, seed: int = 42, @@ -139,6 +172,9 @@ def get_dataloader( ProcessEpisodeAndSlice( seq_len=seq_len, image_h=image_h, image_w=image_w, image_c=image_c ), + DarknessFilter( + darkness_threshold=darkness_threshold + ), grain.transforms.Batch(batch_size=per_process_batch_size, drop_remainder=True), ] From 0a21fe2df38f4901998332b463dcfb50b4971547 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 12 Aug 2025 12:49:46 +0200 Subject: [PATCH 2/3] added darkness threshold to sample.py --- sample.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sample.py b/sample.py index 4d5e327..5889163 100644 --- a/sample.py +++ b/sample.py @@ -59,6 +59,8 @@ class Args: param_dtype = jnp.float32 dtype = jnp.bfloat16 use_flash_attention: bool = True + # Additional parameters + darkness_threshold: float = 0.0 args = tyro.cli(Args) @@ -195,6 +197,7 @@ def _autoreg_sample(genie, rng, video_batch_BSHWC, action_batch_E): num_workers=0, prefetch_buffer_size=1, seed=args.seed, + darkness_threshold=args.darkness_threshold, ) dataloader = iter(dataloader) video_batch_BSHWC = next(dataloader) From f7f7a878f074ca9984ff3f44a7d7092f8cbb9b63 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 2 Sep 2025 11:38:50 +0200 Subject: [PATCH 3/3] add darkness threshold to tokenizer and lam; refactor magic number out of the function --- train_lam.py | 2 ++ train_tokenizer.py | 2 ++ utils/dataloader.py | 3 ++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/train_lam.py b/train_lam.py index 388f355..834850a 100644 --- a/train_lam.py +++ b/train_lam.py @@ -46,6 +46,7 @@ class Args: warmup_steps: int = 5000 lr_schedule: str = "wsd" # supported options: wsd, cos vq_reset_thresh: int = 50 + darkness_threshold: float = 0.0 # LAM model_dim: int = 512 ffn_dim: int = 2048 @@ -297,6 +298,7 @@ def loss_fn( # The dataloader shards the dataset across all processes args.batch_size, *image_shape, + darkness_threshold=args.darkness_threshold, num_workers=8, prefetch_buffer_size=1, seed=args.seed, diff --git a/train_tokenizer.py b/train_tokenizer.py index 5e71b2f..f3c36cb 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -45,6 +45,7 @@ class Args: ) lr_schedule: str = "wsd" # supported options: wsd, cos warmup_steps: int = 10000 + darkness_threshold: float = 0.0 # Tokenizer model_dim: int = 512 ffn_dim: int = 2048 @@ -289,6 +290,7 @@ def loss_fn(model: TokenizerVQVAE) -> tuple[jax.Array, tuple[jax.Array, dict]]: # The dataloader shards the dataset across all processes args.batch_size, *image_shape, + darkness_threshold=args.darkness_threshold, num_workers=8, prefetch_buffer_size=1, seed=args.seed, diff --git a/utils/dataloader.py b/utils/dataloader.py index 39f86f1..b6dc4ef 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -4,6 +4,7 @@ from typing import Any import pickle +RGB_TO_GRAYSCALE_WEIGHTS = np.array([0.2989, 0.5870, 0.1140]) class EpisodeLengthFilter(grain.transforms.Filter): """ @@ -116,7 +117,7 @@ def filter(self, element: Any) -> bool: True if the sequence is not too dark, False otherwise. """ # Convert the RGB image to grayscale using numpy - element_greyscale = np.dot(element[...,:3], [0.2989, 0.5870, 0.1140]) + element_greyscale = np.dot(element[...,:3], RGB_TO_GRAYSCALE_WEIGHTS) average_brightness = np.mean(element_greyscale) if average_brightness < self.darkness_threshold: print(