Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class Args:
param_dtype = jnp.float32
dtype = jnp.bfloat16
use_flash_attention: bool = True
# Additional parameters
darkness_threshold: float = 0.0


args = tyro.cli(Args)
Expand Down Expand Up @@ -195,6 +197,7 @@ def _autoreg_sample(genie, rng, video_batch_BSHWC, action_batch_E):
num_workers=0,
prefetch_buffer_size=1,
seed=args.seed,
darkness_threshold=args.darkness_threshold,
)
dataloader = iter(dataloader)
video_batch_BSHWC = next(dataloader)
Expand Down
2 changes: 2 additions & 0 deletions train_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Args:
)
warmup_steps: int = 5000
lr_schedule: str = "wsd" # supported options: wsd, cos
darkness_threshold: float = 0.0
# Tokenizer
tokenizer_dim: int = 512
tokenizer_ffn_dim: int = 2048
Expand Down Expand Up @@ -319,6 +320,7 @@ def loss_fn(model: Genie) -> tuple[jax.Array, tuple[jax.Array, dict]]:
# The dataloader shards the dataset across all processes
args.batch_size,
*image_shape,
darkness_threshold=args.darkness_threshold,
num_workers=8,
prefetch_buffer_size=1,
seed=args.seed,
Expand Down
2 changes: 2 additions & 0 deletions train_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Args:
warmup_steps: int = 5000
lr_schedule: str = "wsd" # supported options: wsd, cos
vq_reset_thresh: int = 50
darkness_threshold: float = 0.0
# LAM
model_dim: int = 512
ffn_dim: int = 2048
Expand Down Expand Up @@ -297,6 +298,7 @@ def loss_fn(
# The dataloader shards the dataset across all processes
args.batch_size,
*image_shape,
darkness_threshold=args.darkness_threshold,
num_workers=8,
prefetch_buffer_size=1,
seed=args.seed,
Expand Down
2 changes: 2 additions & 0 deletions train_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Args:
)
lr_schedule: str = "wsd" # supported options: wsd, cos
warmup_steps: int = 10000
darkness_threshold: float = 0.0
# Tokenizer
model_dim: int = 512
ffn_dim: int = 2048
Expand Down Expand Up @@ -289,6 +290,7 @@ def loss_fn(model: TokenizerVQVAE) -> tuple[jax.Array, tuple[jax.Array, dict]]:
# The dataloader shards the dataset across all processes
args.batch_size,
*image_shape,
darkness_threshold=args.darkness_threshold,
num_workers=8,
prefetch_buffer_size=1,
seed=args.seed,
Expand Down
37 changes: 37 additions & 0 deletions utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
import pickle

# Per-channel (R, G, B) weights. DarknessFilter takes the dot product of these
# with the first three channels of a sequence to get one grayscale value per pixel.
# NOTE(review): these look like the common BT.601 luma coefficients — confirm.
RGB_TO_GRAYSCALE_WEIGHTS = np.array([0.2989, 0.5870, 0.1140])

class EpisodeLengthFilter(grain.transforms.Filter):
"""
Expand Down Expand Up @@ -96,13 +97,46 @@ def random_map(self, element: dict, rng: np.random.Generator) -> Any:
return seq


class DarknessFilter(grain.transforms.Filter):
    """
    A Grain Filter that filters out sequences with images that are too dark.

    A sequence's brightness is the mean grayscale value over all of its
    pixels. Sequences whose brightness is strictly below the configured
    threshold are dropped; everything else is kept.
    """

    def __init__(self, darkness_threshold: float):
        """
        Initializes the filter with darkness threshold.

        Args:
            darkness_threshold: Minimum average brightness a sequence must
                reach to be kept. It is compared directly against the mean
                pixel value, so it must use the same value range as the data
                (NOTE(review): range depends on the upstream transform — confirm).
        """
        self.darkness_threshold = darkness_threshold

    def filter(self, element: Any) -> bool:
        """
        Filters sequences based on darkness.

        Args:
            element: A NumPy array representing a processed video sequence,
                with the channel axis last.

        Returns:
            True if the sequence is not too dark, False otherwise.
        """
        frames = np.asarray(element)

        if frames.ndim > 0 and frames.shape[-1] >= 3:
            # Convert RGB to grayscale; any channels past the third (e.g. alpha)
            # are ignored.
            grayscale = np.dot(frames[..., :3], RGB_TO_GRAYSCALE_WEIGHTS)
        else:
            # Fewer than three channels: the RGB weights cannot be applied
            # (np.dot would raise on the shape mismatch), so the values are
            # treated as intensities already.
            grayscale = frames

        average_brightness = np.mean(grayscale)
        if average_brightness < self.darkness_threshold:
            # Imported here so the block does not depend on a module-level
            # import; only the (rare) rejection path pays for the lookup.
            import logging

            # Emitted at INFO level: visible only when logging is configured
            # to show INFO messages.
            logging.getLogger(__name__).info(
                "Filtering out sequence with average brightness %s, "
                "which is below the darkness threshold %s.",
                average_brightness,
                self.darkness_threshold,
            )
            return False

        return True


def get_dataloader(
array_record_paths: list[str],
seq_len: int,
global_batch_size: int,
image_h: int,
image_w: int,
image_c: int,
darkness_threshold: float = 0.,
num_workers: int = 1,
prefetch_buffer_size: int = 1,
seed: int = 42,
Expand Down Expand Up @@ -139,6 +173,9 @@ def get_dataloader(
ProcessEpisodeAndSlice(
seq_len=seq_len, image_h=image_h, image_w=image_w, image_c=image_c
),
DarknessFilter(
darkness_threshold=darkness_threshold
),
grain.transforms.Batch(batch_size=per_process_batch_size, drop_remainder=True),
]
Copy link

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DarknessFilter is always instantiated even when darkness_threshold is 0.0 (disabled). Consider conditionally adding the filter only when darkness_threshold > 0.0 to avoid unnecessary processing overhead.

Suggested change
]
]
if darkness_threshold > 0.0:
operations.append(
DarknessFilter(
darkness_threshold=darkness_threshold
)
)
operations.append(
grain.transforms.Batch(batch_size=per_process_batch_size, drop_remainder=True)
)

Copilot uses AI. Check for mistakes.

Expand Down