diff --git a/src/auto_cast/data/datamodule.py b/src/auto_cast/data/datamodule.py
new file mode 100644
index 0000000..8b00508
--- /dev/null
+++ b/src/auto_cast/data/datamodule.py
@@ -0,0 +1,175 @@
+from pathlib import Path
+
+import torch
+from the_well.data.datamodule import WellDataModule
+from the_well.data.normalization import ZScoreNormalization
+from torch.utils.data import DataLoader
+
+from auto_cast.data.dataset import SpatioTemporalDataset
+from auto_cast.types import collate_batches
+
+
+class SpatioTemporalDataModule(WellDataModule):
+    """A data module for spatio-temporal datasets."""
+
+    def __init__(
+        self,
+        data_path: str | None,
+        data: dict[str, dict] | None = None,
+        dataset_cls: type[SpatioTemporalDataset] = SpatioTemporalDataset,
+        n_steps_input: int = 1,
+        n_steps_output: int = 1,
+        stride: int = 1,
+        # TODO: support for passing data from dict
+        input_channel_idxs: tuple[int, ...] | None = None,
+        output_channel_idxs: tuple[int, ...] | None = None,
+        batch_size: int = 4,
+        dtype: torch.dtype = torch.float32,
+        ftype: str = "torch",
+        verbose: bool = False,
+        use_normalization: bool = False,
+    ):
+        self.verbose = verbose
+        self.use_normalization = use_normalization
+
+        base_path = Path(data_path) if data_path is not None else None
+        suffix = ".pt" if ftype == "torch" else ".h5"
+        fname = f"data{suffix}"
+        train_path = base_path / "train" / fname if base_path is not None else None
+        valid_path = base_path / "valid" / fname if base_path is not None else None
+        test_path = base_path / "test" / fname if base_path is not None else None
+
+        # Create training dataset first (without normalization)
+        self.train_dataset = dataset_cls(
+            data_path=str(train_path) if train_path is not None else None,
+            data=data["train"] if data is not None else None,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            stride=stride,
+            input_channel_idxs=input_channel_idxs,
+            output_channel_idxs=output_channel_idxs,
+            dtype=dtype,
+            verbose=self.verbose,
+            use_normalization=False,  # Temporarily disable to compute stats
+            norm=None,
+        )
+
+        # Compute normalization from training data if requested
+        norm = None
+        if self.use_normalization:
+            if self.verbose:
+                print("Computing normalization statistics from training data...")
+            norm = ZScoreNormalization
+            # if self.verbose:
+            #     print(f"  Mean (per channel): {norm.mean}")
+            #     print(f"  Std (per channel): {norm.std}")
+
+        # Now apply the normalization settings to the training dataset
+        self.train_dataset.use_normalization = self.use_normalization
+        self.train_dataset.norm = norm
+
+        self.val_dataset = dataset_cls(
+            data_path=str(valid_path) if valid_path is not None else None,
+            data=data["valid"] if data is not None else None,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            stride=stride,
+            input_channel_idxs=input_channel_idxs,
+            output_channel_idxs=output_channel_idxs,
+            dtype=dtype,
+            verbose=self.verbose,
+            use_normalization=self.use_normalization,
+            norm=norm,
+        )
+        self.test_dataset = dataset_cls(
+            data_path=str(test_path) if test_path is not None else None,
+            data=data["test"] if data is not None else None,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            stride=stride,
+            input_channel_idxs=input_channel_idxs,
+            output_channel_idxs=output_channel_idxs,
+            dtype=dtype,
+            verbose=self.verbose,
+            use_normalization=self.use_normalization,
+            norm=norm,
+        )
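+        # Rollout datasets use full_trajectory_mode=True so that each sample
+        # spans a whole trajectory for autoregressive (rollout) evaluation.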
+        self.rollout_val_dataset = dataset_cls(
+            data_path=str(valid_path) if valid_path is not None else None,
+            data=data["valid"] if data is not None else None,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            stride=stride,
+            input_channel_idxs=input_channel_idxs,
+            output_channel_idxs=output_channel_idxs,
+            full_trajectory_mode=True,
+            dtype=dtype,
+            verbose=self.verbose,
+            use_normalization=self.use_normalization,
+            norm=norm,
+        )
+        self.rollout_test_dataset = dataset_cls(
+            data_path=str(test_path) if test_path is not None else None,
+            data=data["test"] if data is not None else None,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            stride=stride,
+            input_channel_idxs=input_channel_idxs,
+            output_channel_idxs=output_channel_idxs,
+            full_trajectory_mode=True,
+            dtype=dtype,
+            verbose=self.verbose,
+            use_normalization=self.use_normalization,
+            norm=norm,
+        )
+        self.batch_size = batch_size
+
+    def train_dataloader(self) -> DataLoader:
+        """DataLoader for training."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=1,
+            collate_fn=collate_batches,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        """DataLoader for standard validation (not full trajectory rollouts)."""
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=1,
+            collate_fn=collate_batches,
+        )
+
+    def rollout_val_dataloader(self) -> DataLoader:
+        """DataLoader for full trajectory rollouts on validation data."""
+        return DataLoader(
+            self.rollout_val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=1,
+            collate_fn=collate_batches,
+        )
+
+    def test_dataloader(self) -> DataLoader:
+        """DataLoader for testing."""
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=1,
+            collate_fn=collate_batches,
+        )
+
+    def rollout_test_dataloader(self) -> DataLoader:
+        """DataLoader for full trajectory rollouts on test data."""
+        return DataLoader(
+            self.rollout_test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=1,
+            collate_fn=collate_batches,
+        )
diff --git a/src/auto_cast/data/dataset.py b/src/auto_cast/data/dataset.py
new file mode 100644
index 0000000..df7cdff
--- /dev/null
+++ b/src/auto_cast/data/dataset.py
@@ -0,0 +1,362 @@
+from collections.abc import Callable
+from typing import Any, Literal
+
+import h5py
+import torch
+from the_well.data import Augmentation, WellDataset
+from the_well.data.normalization import ZScoreNormalization
+from torch.utils.data import Dataset
+
+from auto_cast.data.metadata import Metadata
+from auto_cast.types import Batch
+
+
+class BatchMixin:
+    """A mixin class to provide Batch conversion functionality."""
+
+    @staticmethod
+    def to_batch(data: dict) -> Batch:
+        """Convert a dictionary of tensors to a Batch object."""
+        return Batch(
+            input_fields=data["input_fields"],
+            output_fields=data["output_fields"],
+            constant_scalars=data.get("constant_scalars"),
+            constant_fields=data.get("constant_fields"),
+        )
+
+
+class SpatioTemporalDataset(Dataset, BatchMixin):
+    """A dataset of spatio-temporal trajectories sliced into input/output windows."""
+
+    def __init__(
+        self,
+        data_path: str | None,
+        data: dict | None = None,
+        n_steps_input: int = 1,
+        n_steps_output: int = 1,
+        stride: int = 1,
+        # TODO: support for passing data from dict
+        input_channel_idxs: tuple[int, ...] | None = None,
+        output_channel_idxs: tuple[int, ...] | None = None,
+        full_trajectory_mode: bool = False,
+        dtype: torch.dtype = torch.float32,
+        verbose: bool = False,
+        use_normalization: bool = False,
+        norm: type[ZScoreNormalization] | None = None,
+    ):
+        """
+        Initialize the dataset.
+
+        Parameters
+        ----------
+        data_path: str | None
+            Path to the HDF5 or torch file containing the dataset.
+        data: dict | None
+            Preloaded data. Defaults to None.
+        n_steps_input: int
+            Number of input time steps.
+        n_steps_output: int
+            Number of output time steps.
+        stride: int
+            Stride for sampling the data.
+        input_channel_idxs: tuple[int, ...] | None
+            Indices of input channels to use. Defaults to None.
+        output_channel_idxs: tuple[int, ...] | None
+            Indices of output channels to use. Defaults to None.
+        full_trajectory_mode: bool
+            If True, use full trajectories without creating subtrajectories.
+        dtype: torch.dtype
+            Data type for tensors. Defaults to torch.float32.
+        verbose: bool
+            If True, print dataset information.
+        use_normalization: bool
+            Whether to apply Z-score normalization. Defaults to False.
+        norm: type[ZScoreNormalization] | None
+            Normalization class (computed from training data). Defaults to None.
+        """
+        self.dtype = dtype
+        self.verbose = verbose
+        self.use_normalization = use_normalization
+        self.norm = norm
+
+        # Read data from disk or parse the preloaded dictionary
+        if data_path is not None:
+            self.read_data(data_path)
+        else:
+            self.parse_data(data)
+
+        self.full_trajectory_mode = full_trajectory_mode
+        self.n_steps_input = n_steps_input
+        self.n_steps_output = (
+            n_steps_output
+            if not self.full_trajectory_mode
+            # TODO: make more robust and flexible for different trajectory lengths
+            else self.data.shape[1] - self.n_steps_input
+        )
+        self.stride = stride
+        self.input_channel_idxs = input_channel_idxs
+        self.output_channel_idxs = output_channel_idxs
+
+        # Unpack the data shape
+        (
+            self.n_trajectories,
+            self.n_timesteps,
+            self.width,
+            self.height,
+            self.n_channels,
+        ) = self.data.shape
+
+        # Pre-compute all subtrajectories for efficient indexing
+        self.all_input_fields = []
+        self.all_output_fields = []
+        self.all_constant_scalars = []
+        self.all_constant_fields = []
+
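+        # Each trajectory is sliced into overlapping windows of
+        # n_steps_input + n_steps_output frames, advancing by `stride` frames.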
+        for traj_idx in range(self.n_trajectories):
+            # Create subtrajectories for this trajectory
+            fields = (
+                self.data[traj_idx]
+                .unfold(0, self.n_steps_input + self.n_steps_output, self.stride)
+                .permute(0, -1, 1, 2, 3)  # [num_subtrajectories, T_in + T_out, W, H, C]
+            )
+
+            # Split into input and output
+            input_fields = fields[
+                :, : self.n_steps_input, ...
+            ]  # [num_subtrajectories, T_in, W, H, C]
+            output_fields = fields[
+                :, self.n_steps_input :, ...
+            ]  # [num_subtrajectories, T_out, W, H, C]
+
+            # Store each subtrajectory separately
+            for sub_idx in range(input_fields.shape[0]):
+                self.all_input_fields.append(
+                    input_fields[sub_idx].to(self.dtype)
+                )  # [T_in, W, H, C]
+                self.all_output_fields.append(
+                    output_fields[sub_idx].to(self.dtype)
+                )  # [T_out, W, H, C]
+
+                # Handle constant scalars (repeated for every subtrajectory)
+                if self.constant_scalars is not None:
+                    self.all_constant_scalars.append(
+                        self.constant_scalars[traj_idx].to(self.dtype)
+                    )
+
+                # Handle constant fields (repeated for every subtrajectory)
+                if self.constant_fields is not None:
+                    self.all_constant_fields.append(
+                        self.constant_fields[traj_idx].to(self.dtype)
+                    )
+
+        if self.verbose:
+            print(f"Created {len(self.all_input_fields)} subtrajectory samples")
+            print(f"Each input sample shape: {self.all_input_fields[0].shape}")
+            print(f"Each output sample shape: {self.all_output_fields[0].shape}")
+            print(f"Data type: {self.all_input_fields[0].dtype}")
+
+    def _from_f(self, f):
+        assert "data" in f, "HDF5 file must contain 'data' dataset"
+        self.data = torch.Tensor(f["data"][:]).to(self.dtype)  # type: ignore # [N, T, W, H, C] # noqa: PGH003
+        if self.verbose:
+            print(f"Loaded data shape: {self.data.shape}")
+        # TODO: add the constant scalars
+        self.constant_scalars = (
+            torch.Tensor(f["constant_scalars"][:]).to(self.dtype)  # type: ignore # noqa: PGH003
+            if "constant_scalars" in f
+            else None
+        )  # [N, C]
+
+        # Constant fields
+        self.constant_fields = (
+            torch.Tensor(f["constant_fields"][:]).to(  # type: ignore # noqa: PGH003
+                self.dtype
+            )  # [N, W, H, C]
+            if "constant_fields" in f and f["constant_fields"] != {}
+            else None
+        )
+
+    def read_data(self, data_path: str):
+        """Read data from `data_path`.
+
+        Supports HDF5 (`.h5`/`.hdf5`) and torch (`.pt`) files with the expected
+        shape and fields.
+        """
+        self.data_path = data_path
+        if self.data_path.endswith(".h5") or self.data_path.endswith(".hdf5"):
+            with h5py.File(self.data_path, "r") as f:
+                self._from_f(f)
+        elif self.data_path.endswith(".pt"):
+            self._from_f(torch.load(self.data_path))
+
+    def parse_data(self, data: dict | None):
+        """Parse data from a dictionary."""
+        if data is not None:
+            self.data = (
+                data["data"].to(self.dtype)
+                if torch.is_tensor(data["data"])
+                else torch.tensor(data["data"], dtype=self.dtype)
+            )
+            self.constant_scalars = data.get("constant_scalars", None)
+            self.constant_fields = data.get("constant_fields", None)
+            return
+        msg = "No data provided to parse."
+        raise ValueError(msg)
+
+    def __len__(self):  # noqa: D105
+        return len(self.all_input_fields)
+
+    def __getitem__(self, idx):  # noqa: D105
+        input_fields = self.all_input_fields[idx]
+        output_fields = self.all_output_fields[idx]
+
+        item = {
+            "input_fields": input_fields,
+            "output_fields": output_fields,
+        }
+        if len(self.all_constant_scalars) > 0:
+            item["constant_scalars"] = self.all_constant_scalars[idx]
+        if len(self.all_constant_fields) > 0:
+            item["constant_fields"] = self.all_constant_fields[idx]
+
+        return self.to_batch(item)
+
+
+class ReactionDiffusionDataset(SpatioTemporalDataset):
+    """Reaction-Diffusion dataset."""
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.metadata = Metadata(
+            dataset_name="ReactionDiffusion",
+            n_spatial_dims=2,
+            spatial_resolution=self.data.shape[-3:-1],
+            scalar_names=[],
+            constant_scalar_names=["beta", "d"],
+            field_names={0: ["U", "V"]},
+            constant_field_names={},
+            boundary_condition_types=["periodic", "periodic"],
+            n_files=0,
+            n_trajectories_per_file=[],
+            n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0],
+            grid_type="cartesian",
+        )
+
+
+class AdvectionDiffusionDataset(SpatioTemporalDataset):
+    """Advection-Diffusion dataset."""
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.metadata = Metadata(
+            dataset_name="AdvectionDiffusion",
+            n_spatial_dims=2,
+            spatial_resolution=self.data.shape[-3:-1],
+            scalar_names=[],
+            constant_scalar_names=["nu", "mu"],
+            field_names={0: ["vorticity"]},
+            constant_field_names={},
+            boundary_condition_types=["periodic", "periodic"],
+            n_files=0,
+            n_trajectories_per_file=[],
+            n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0],
+            grid_type="cartesian",
+        )
+
+
+class BOUTDataset(SpatioTemporalDataset):
+    """BOUT++ dataset."""
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.metadata = Metadata(
+            dataset_name="BOUT++",
+            n_spatial_dims=2,
+            spatial_resolution=self.data.shape[-3:-1],
+            scalar_names=[],
+            constant_scalar_names=[
+                f"const{i}"
+                for i in range(self.constant_scalars.shape[-1])  # type: ignore # noqa: PGH003
+            ],
+            field_names={0: ["vorticity"]},
+            constant_field_names={},
+            boundary_condition_types=["periodic", "periodic"],
+            n_files=0,
+            n_trajectories_per_file=[],
+            n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0],
+            grid_type="cartesian",
+        )
+
+
+class TheWell(SpatioTemporalDataset):
+    """A wrapper around The Well's WellDataset to provide Batch objects."""
+
+    well_dataset: WellDataset
+
+    def __init__(
+        self,
+        path: None | str = None,
+        normalization_path: str = "../stats.yaml",
+        well_base_path: None | str = None,
+        well_dataset_name: None | str = None,
+        well_split_name: Literal["train", "valid", "test", None] = None,
+        include_filters: list[str] | None = None,
+        exclude_filters: list[str] | None = None,
+        use_normalization: bool = False,
+        normalization_type: None | Callable[..., Any] = None,
+        max_rollout_steps=100,
+        n_steps_input: int = 1,
+        n_steps_output: int = 1,
+        min_dt_stride: int = 1,
+        max_dt_stride: int = 1,
+        flatten_tensors: bool = True,
+        cache_small: bool = True,
+        max_cache_size: float = 1e9,
+        return_grid: bool = True,
+        boundary_return_type: str = "padding",
+        full_trajectory_mode: bool = False,
+        name_override: None | str = None,
+        transform: None | Augmentation = None,
+        min_std: float = 1e-4,
+        storage_options: None | dict = None,
+    ):
+        exclude_filters = exclude_filters or []
+        include_filters = include_filters or []
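+        # All arguments are forwarded verbatim to the_well's WellDataset; this
+        # wrapper only converts the returned items into Batch objects.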
+        self.well_dataset = WellDataset(
+            path=path,
+            normalization_path=normalization_path,
+            well_base_path=well_base_path,
+            well_dataset_name=well_dataset_name,
+            well_split_name=well_split_name,
+            include_filters=include_filters,
+            exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
+            max_rollout_steps=max_rollout_steps,
+            n_steps_input=n_steps_input,
+            n_steps_output=n_steps_output,
+            min_dt_stride=min_dt_stride,
+            max_dt_stride=max_dt_stride,
+            flatten_tensors=flatten_tensors,
+            cache_small=cache_small,
+            max_cache_size=max_cache_size,
+            return_grid=return_grid,
+            boundary_return_type=boundary_return_type,
+            full_trajectory_mode=full_trajectory_mode,
+            name_override=name_override,
+            transform=transform,
+            min_std=min_std,
+            storage_options=storage_options,
+        )
+        self.well_metadata = self.well_dataset.metadata
+
+    def __len__(self):  # noqa: D105
+        return len(self.well_dataset)
+
+    def __getitem__(self, index) -> Batch:  # noqa: D105
+        return self.to_batch(self.well_dataset[index])
diff --git a/src/auto_cast/data/metadata.py b/src/auto_cast/data/metadata.py
new file mode 100644
index 0000000..36eb630
--- /dev/null
+++ b/src/auto_cast/data/metadata.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from typing import Literal
+
+
+@dataclass
+class Metadata:
+    """Metadata for spatiotemporal datasets."""
+
+    dataset_name: str
+    n_spatial_dims: int
+    spatial_resolution: tuple[int, ...]
+    scalar_names: list[str]
+    constant_scalar_names: list[str]
+    constant_field_names: dict[str, list[str]]
+    boundary_condition_types: list[str]
+    field_names: dict[int, list[str]]
+    n_steps_per_trajectory: list[int]
+    n_files: int | None = None
+    n_trajectories_per_file: list[int] | None = None
+    grid_type: Literal["cartesian"] = "cartesian"
diff --git a/src/auto_cast/types/__init__.py b/src/auto_cast/types/__init__.py
index 711a8fa..7b747c0 100644
--- a/src/auto_cast/types/__init__.py
+++ b/src/auto_cast/types/__init__.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from dataclasses import dataclass
 
 import torch
@@ -16,10 +17,10 @@ class Batch:
     """A batch in input data space."""
 
-    input_fields: Tensor
-    output_fields: Tensor
-    constant_scalars: Tensor
-    constant_fields: Tensor
+    input_fields: Tensor  # (B, T, W, H, C)
+    output_fields: Tensor  # (B, T, W, H, C)
+    constant_scalars: Tensor | None  # (B, C)
+    constant_fields: Tensor | None  # (B, W, H, C)
 
 
 @dataclass
@@ -40,3 +41,31 @@ def __call__(self, batch: Batch) -> EncodedBatch:
             encoded_output_fields=batch.output_fields,
             encoded_info={},
         )
+
+
+def collate_batches(samples: Sequence[Batch]) -> Batch:
+    """Stack a sequence of `Batch` instances along the batch dimension."""
+    if len(samples) == 0:
+        msg = "collate_batches expects at least one sample"
+        raise ValueError(msg)
+
+    def _stack_optional(getter: str) -> Tensor | None:
+        values = [getattr(sample, getter) for sample in samples]
+        if all(v is None for v in values):
+            return None
+        if any(v is None for v in values):
+            msg = f"Field '{getter}' is inconsistently None across samples"
+            raise ValueError(msg)
+        return torch.stack(values, dim=0)  # type: ignore[arg-type]
+
+    input_fields = torch.stack([sample.input_fields for sample in samples], dim=0)
+    output_fields = torch.stack([sample.output_fields for sample in samples], dim=0)
+    constant_scalars = _stack_optional("constant_scalars")
+    constant_fields = _stack_optional("constant_fields")
+
+    return Batch(
+        input_fields=input_fields,
+        output_fields=output_fields,
+        constant_scalars=constant_scalars,
+        constant_fields=constant_fields,
+    )
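
For reviewers, a minimal usage sketch (not part of the patch): it assumes an in-memory data dict with made-up shapes and exercises the new data module end to end.

```python
import torch

from auto_cast.data.datamodule import SpatioTemporalDataModule

# Hypothetical data: 8 trajectories, 16 timesteps, 32x32 grid, 2 channels.
splits = {
    split: {"data": torch.randn(8, 16, 32, 32, 2)}
    for split in ("train", "valid", "test")
}

dm = SpatioTemporalDataModule(
    data_path=None,
    data=splits,
    n_steps_input=2,
    n_steps_output=1,
    batch_size=4,
)

batch = next(iter(dm.train_dataloader()))
print(batch.input_fields.shape)   # torch.Size([4, 2, 32, 32, 2])
print(batch.output_fields.shape)  # torch.Size([4, 1, 32, 32, 2])
```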