Merged
Commits
32 commits
a566fd6
Port spatiotemporal dataset
sgreenbury Nov 28, 2025
134d7f8
Add dep
sgreenbury Nov 28, 2025
fc3d46e
Rename class
sgreenbury Nov 28, 2025
4cef570
Add example FNO
sgreenbury Nov 28, 2025
971a0e4
Add exploratory notebook
sgreenbury Nov 28, 2025
9d3ec0c
Fix FNO params
sgreenbury Nov 28, 2025
7259e0f
Merge remote-tracking branch 'origin/main' into add-datasets-with-main
sgreenbury Dec 1, 2025
bb45595
Add permute_concat encoder, add shapes
sgreenbury Dec 1, 2025
c1ec459
Update dataset to return batch
sgreenbury Dec 1, 2025
9e1fd7d
Add ChannelsLast decoder
sgreenbury Dec 1, 2025
098a07c
Update EncoderDecoder and EncoderProcessorDecoder
sgreenbury Dec 1, 2025
95448b6
Refactor FNO to module
sgreenbury Dec 1, 2025
15afb56
Add collate_batches for dataclass
sgreenbury Dec 1, 2025
5cf0acc
Update factory
sgreenbury Dec 1, 2025
7961c7c
Fix init
sgreenbury Dec 1, 2025
8662199
Extend EncoderProcessorDecoder
sgreenbury Dec 1, 2025
be6931b
Update exploratory notebook
sgreenbury Dec 1, 2025
d983224
Fix lint
sgreenbury Dec 1, 2025
48baa5c
Update rollout
sgreenbury Dec 1, 2025
3618144
Add tests
sgreenbury Dec 1, 2025
21329ef
Add BatchMixin, add WellDataset
sgreenbury Dec 1, 2025
0697dcd
Merge branch '14-add-datasets' into 34-fno
sgreenbury Dec 1, 2025
e4ab982
Update function signatures
sgreenbury Dec 2, 2025
b7af0aa
Fix test processor
sgreenbury Dec 2, 2025
32f08fb
Add ABC to base, fix FNO
sgreenbury Dec 2, 2025
87993ee
Rename
sgreenbury Dec 2, 2025
3a91f7c
Merge branch '14-add-datasets' into 34-fno
sgreenbury Dec 2, 2025
e5c9005
Add test_fno, add conftest for shared fixtures
sgreenbury Dec 2, 2025
b1fcbb8
Fix FNO init to ensure loss_func
sgreenbury Dec 2, 2025
ae06d03
Fix or suppress test warnings, add default optimizer
sgreenbury Dec 2, 2025
1e9b169
Merge remote-tracking branch 'origin/main' into 34-fno
sgreenbury Dec 4, 2025
16d6333
Remove init from base encoder and decoder
sgreenbury Dec 4, 2025
177 changes: 177 additions & 0 deletions notebooks/00_exploration.ipynb
@@ -0,0 +1,177 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"## AutoCast encoder-processor-decoder model API Exploration\n",
"\n",
"This notebook aims to explore the end-to-end API.\n"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"### Example dataaset\n",
"\n",
"We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
"\n",
"sim = AdvectionDiffusion(return_timeseries=True, log_level=\"error\")\n",
"\n",
"def generate_split(\n",
" simulator: AdvectionDiffusion, n_train: int = 4, n_valid: int = 2, n_test: int = 2\n",
"):\n",
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
"\n",
"\n",
"combined_data = generate_split(sim)"
]
},
{
"cell_type": "markdown",
"id": "3",
"metadata": {},
"source": [
"### Read combined data into datamodule\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
"\n",
"datamodule = SpatioTemporalDataModule(\n",
" data=combined_data, data_path=None, n_steps_input=4, n_steps_output=1, batch_size=16\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"### Example batch\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(datamodule.train_dataloader()))\n",
"\n",
"# batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"from auto_cast.decoders.channels_last import ChannelsLast\n",
"from auto_cast.encoders.permute_concat import PermuteConcat\n",
"from auto_cast.models.encoder_decoder import EncoderDecoder\n",
"from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder\n",
"from auto_cast.nn.fno import FNOProcessor\n",
"\n",
"processor = FNOProcessor(\n",
" in_channels=1, out_channels=1, n_modes=(16, 16, 1), hidden_channels=64\n",
")\n",
"encoder = PermuteConcat(with_constants=False)\n",
"decoder = ChannelsLast()\n",
"\n",
"model = EncoderProcessorDecoder.from_encoder_processor_decoder(\n",
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
" processor=processor,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"### Run trainer\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"import lightning as L\n",
"\n",
"device = \"mps\" # \"cpu\"\n",
"trainer = L.Trainer(max_epochs=5, accelerator=device, log_every_n_steps=10)\n",
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"### Run the evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model, datamodule.test_dataloader())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
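A quick way to sanity-check the encoder-processor-decoder composition in the notebook is to trace tensor shapes through the two rearrangements. The sketch below uses illustrative sizes (not taken from the notebook) and stands in for the FNOProcessor with an identity, since only the layout changes matter here.

import torch
from einops import rearrange

# Illustrative sizes: batch=16, n_steps_input=4, 32x32 grid, 1 channel
x = torch.randn(16, 4, 32, 32, 1)  # (b, t, w, h, c) as yielded by the datamodule

z = rearrange(x, "b t w h c -> b c t w h")  # PermuteConcat: channels-first
assert z.shape == (16, 1, 4, 32, 32)

y = z  # stand-in for FNOProcessor, assumed to map (b, c, t, w, h) -> (b, c, t, w, h)

out = rearrange(y, "b c t w h -> b t w h c")  # ChannelsLast: back to channels-last
assert out.shape == (16, 4, 32, 32, 1)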
13 changes: 13 additions & 0 deletions pyproject.toml
@@ -8,9 +8,11 @@ authors = [
]
requires-python = ">=3.11,<3.13"
dependencies = [
"autoemulate>=1.2.0",
"einops>=0.8.1",
"h5py>=3.15.1",
"lightning>=2.5.6",
"neuraloperator>=2.0.0",
"the-well>=1.1.0",
"torch>=2.9.1",
]
@@ -91,3 +93,14 @@ convention = "numpy"

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D"]

[tool.uv.sources]
autoemulate = { git = "https://github.com/alan-turing-institute/autoemulate.git" }

[tool.pytest.ini_options]
filterwarnings = [
    # Ignore Lightning warnings that are expected/benign in test environment
    "ignore:You are trying to `self.log\\(\\)` but the `self.trainer` reference is not registered:UserWarning",
    "ignore:GPU available but not used:UserWarning",
    "ignore:The '.*_dataloader' does not have many workers:UserWarning",
]
5 changes: 0 additions & 5 deletions src/auto_cast/decoders/base.py
@@ -9,11 +9,6 @@
 class Decoder(nn.Module, ABC):
     """Base Decoder."""
 
-    def __init__(self, latent_dim: int, output_channels: int) -> None:
-        super().__init__()
-        self.latent_dim = latent_dim
-        self.output_channels = output_channels
-
     def decode(self, z: Tensor) -> Tensor:
         """Decode the latent tensor back to the original space.
 
12 changes: 12 additions & 0 deletions src/auto_cast/decoders/channels_last.py
@@ -0,0 +1,12 @@
from einops import rearrange

from auto_cast.decoders.base import Decoder
from auto_cast.types import Tensor


class ChannelsLast(Decoder):
    """Decoder that rearranges channels-first tensors to channels-last layout."""

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the ChannelsLast decoder."""
        return rearrange(x, "b c t w h -> b t w h c")
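For reference, a minimal sketch of the decoder in isolation, assuming the channels-first layout produced by the encoder:

import torch

from auto_cast.decoders.channels_last import ChannelsLast

decoder = ChannelsLast()
x = torch.randn(2, 3, 4, 8, 8)  # (b, c, t, w, h)
y = decoder(x)
assert y.shape == (2, 4, 8, 8, 3)  # (b, t, w, h, c)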
5 changes: 0 additions & 5 deletions src/auto_cast/encoders/base.py
@@ -9,11 +9,6 @@
 class Encoder(nn.Module, ABC):
     """Base encoder."""
 
-    def __init__(self, latent_dim: int, input_channels: int) -> None:
-        super().__init__()
-        self.latent_dim = latent_dim
-        self.input_channels = input_channels
-
     def encode(self, x: Tensor) -> Tensor:
         """Encode the input tensor into the latent space.
 
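Dropping the parameterized __init__ here (and in decoders/base.py above) means subclasses are no longer forced to carry latent_dim and input_channels. A hypothetical stateless encoder, not part of this PR, now reduces to:

from auto_cast.encoders.base import Encoder
from auto_cast.types import Tensor


class IdentityEncoder(Encoder):
    """Hypothetical stateless encoder: returns its input unchanged."""

    def encode(self, x: Tensor) -> Tensor:
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self.encode(x)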
29 changes: 29 additions & 0 deletions src/auto_cast/encoders/permute_concat.py
@@ -0,0 +1,29 @@
import torch
from einops import rearrange

from auto_cast.encoders.base import Encoder
from auto_cast.types import Batch, Tensor


class PermuteConcat(Encoder):
    """Permute and concatenate Encoder."""

    def __init__(self, with_constants: bool = False) -> None:
        super().__init__()
        self.with_constants = with_constants

    def forward(self, batch: Batch) -> Tensor:
        """Forward pass through the PermuteConcat encoder."""
        # Destructure batch, time, space, channels
        b, t, w, h, _ = batch.input_fields.shape  # TODO: generalize beyond 2D spatial
        x = batch.input_fields
        x = rearrange(x, "b t w h c -> b c t w h")
        if self.with_constants and batch.constant_fields is not None:
            constants = batch.constant_fields
            constants = rearrange(constants, "b w h c -> b c 1 w h")
            # Broadcast constants across time so torch.cat dims match x
            constants = constants.expand(b, -1, t, w, h)
            x = torch.cat([x, constants], dim=1)
        if self.with_constants and batch.constant_scalars is not None:
            scalars = batch.constant_scalars
            scalars = rearrange(scalars, "b c -> b c 1 1 1")
            scalars = scalars.expand(b, -1, t, w, h)
            x = torch.cat([x, scalars], dim=1)
        return x
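A sketch of the channel concatenation with constants enabled. The real Batch lives in auto_cast.types; a stand-in dataclass with the same field names is used here so the example is self-contained, and the shapes are assumptions following the rearrange patterns above (constants broadcast across time before concatenation).

from dataclasses import dataclass

import torch

from auto_cast.encoders.permute_concat import PermuteConcat


@dataclass
class FakeBatch:  # stand-in for auto_cast.types.Batch
    input_fields: torch.Tensor  # (b, t, w, h, c)
    constant_fields: torch.Tensor  # (b, w, h, c_const)
    constant_scalars: torch.Tensor  # (b, n_scalars)


batch = FakeBatch(
    input_fields=torch.randn(2, 4, 8, 8, 1),
    constant_fields=torch.randn(2, 8, 8, 2),
    constant_scalars=torch.randn(2, 3),
)
encoder = PermuteConcat(with_constants=True)
x = encoder(batch)
assert x.shape == (2, 6, 4, 8, 8)  # 1 field + 2 constant + 3 scalar channels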
27 changes: 17 additions & 10 deletions src/auto_cast/models/encoder_decoder.py
@@ -1,5 +1,3 @@
-from typing import Any
-
 import lightning as L
 import torch
 from torch import nn
@@ -14,16 +12,25 @@ class EncoderDecoder(L.LightningModule):
 
     encoder: Encoder
     decoder: Decoder
-    loss_func: nn.Module
+    loss_func: nn.Module | None
 
-    def __init__(self):
-        pass
+    def __init__(
+        self, encoder: Encoder, decoder: Decoder, loss_func: nn.Module | None = None
+    ) -> None:
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.loss_func = loss_func
 
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        return self.decoder(self.encoder(*args, **kwargs))
+    def forward(self, batch: Batch) -> Tensor:
+        return self.decoder(self.encoder(batch))
 
     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        output = self.encode(batch)
+        if self.loss_func is None:
+            msg = "Loss function not defined for EncoderDecoder model."
+            raise ValueError(msg)
+        x = self.encode(batch)
+        output = self.decoder(x)
         loss = self.loss_func(output, batch.output_fields)
         return loss  # noqa: RET504
 
@@ -43,8 +50,8 @@ def configure_optmizers(self):
 class VAE(EncoderDecoder):
     """Variational Autoencoder Model."""
 
-    def forward(self, x: Tensor) -> Tensor:
-        mu, log_var = self.encoder(x)
+    def forward(self, batch: Batch) -> Tensor:
+        mu, log_var = self.encoder(batch)
         z = self.reparametrize(mu, log_var)
         x = self.decoder(z)
         return x  # noqa: RET504
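VAE.forward calls self.reparametrize, which falls outside the hunks shown here. For context, a standard reparameterization-trick implementation looks like the sketch below; this is the usual technique, not necessarily the exact method defined elsewhere in the class.

import torch
from torch import Tensor


def reparametrize(mu: Tensor, log_var: Tensor) -> Tensor:
    """Sample z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and log_var."""
    std = torch.exp(0.5 * log_var)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)  # noise drawn from N(0, I)
    return mu + eps * std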