📄 Minimal Convolutional RNNs Accelerate Spatiotemporal Learning

This repository contains the PyTorch implementation of our paper, featuring code for data generation, model training, and evaluation.
Coşku Can Horuz, Sebastian Otte, Martin V. Butz, Matthias Karlbauer
Adaptive AI Lab, University of Lübeck
Neuro-Cognitive Modeling Group, University of Tübingen
Our work introduces efficient, minimal, parallel-scan-compatible convolutional RNN architectures for spatiotemporal prediction tasks. These lightweight models achieve strong performance on synthetic and real-world datasets while being significantly faster and simpler than existing baselines.
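The speed advantage of parallel-scan-compatible RNNs comes from their hidden update being an affine recurrence, `h_t = a_t * h_{t-1} + b_t`, which composes associatively and can therefore be evaluated in O(log T) parallel steps instead of a sequential loop over time. The NumPy sketch below illustrates this idea only; it is not the repository's implementation, and in the convolutional variants the per-pixel gates `a_t` and `b_t` would be produced by convolutions over the input.

```python
import numpy as np

# Affine recurrence h_t = a_t * h_{t-1} + b_t (elementwise), with h_0 = 0.
def sequential_recurrence(a, b):
    h = np.zeros_like(b[0])
    out = np.empty_like(b)
    for t in range(len(a)):
        h = a[t] * h + b[t]
        out[t] = h
    return out

# Hillis-Steele inclusive scan over the associative operator
# (a1, b1) o (a2, b2) = (a2 * a1, a2 * b1 + b2),
# i.e. composition of the affine maps h -> a * h + b.
def scan_recurrence(a, b):
    a, b = a.copy(), b.copy()
    step = 1
    while step < len(a):
        # Combine each element with the one `step` positions earlier.
        b[step:] = a[step:] * b[:-step] + b[step:]
        a[step:] = a[step:] * a[:-step]
        step *= 2
    return b

rng = np.random.default_rng(0)
T, H, W = 16, 4, 4                    # time steps, spatial grid
a = rng.uniform(0.0, 1.0, (T, H, W))  # forget/decay gates
b = rng.normal(size=(T, H, W))        # candidate updates
assert np.allclose(sequential_recurrence(a, b), scan_recurrence(a, b))
```

On parallel hardware, the scan's log-depth structure is what makes these models significantly faster than gate-recurrent baselines such as a standard ConvLSTM, whose nonlinear state dependence forces step-by-step evaluation.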
To install the `mcrnn` (Minimal Convolutional RNN) package:

- Create and activate a new environment:

  ```shell
  conda create -n mcrnn python=3.11 -y && conda activate mcrnn
  ```

- Change into the `mcrnn` directory:

  ```shell
  cd src/mcrnn
  ```

- Install the package:

  ```shell
  pip install -e .
  ```
⚠️ Data generation requires PyTorch v1.6; we recommend using a separate environment to avoid version conflicts.

- Create a separate environment:

  ```shell
  conda create -n nsgen python=3.7 -y && conda activate nsgen
  pip install torch==1.6.0 "xarray[parallel]" scipy einops tqdm matplotlib
  ```
- Generate data:

  ```shell
  python data/data_generation/navier-stokes/generate_ns_2d.py
  ```
Data generation can take a long time, depending on the number of samples generated and the simulation configuration. A small dataset for experimentation, with 5 samples and a simulation time of 50, can be created as follows:

```shell
python data/data_generation/navier-stokes/generate_ns_2d.py -n 5 -t 50 -b 1
```
To create train/val/test sets:

```shell
python data/data_generation/navier-stokes/generate_ns_2d.py -n 1000 -t 50  # train
python data/data_generation/navier-stokes/generate_ns_2d.py -n 50 -t 50    # val
python data/data_generation/navier-stokes/generate_ns_2d.py -n 200 -t 51   # test
```

Tip: visualize a single sample in an animation:

```shell
python data/data_generation/navier-stokes/generate_ns_2d.py --animate -n 1 -b 1 -r 32
```

As a real-world case, we consider the air pressure field at 500 hPa as a key component of our planet's weather system. WeatherBench offers various atmospheric variables at different spatial resolutions. We use [1, 32, 64] in [C, H, W].
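The [32, 64] spatial extent follows directly from WeatherBench's 5.625° grid spacing, and geopotential at 500 hPa is a single variable, hence one channel. As a quick arithmetic check:

```python
# WeatherBench's 5.625-degree grid tiles the globe:
lat_points = round(180 / 5.625)  # 32 latitude rows  (H)
lon_points = round(360 / 5.625)  # 64 longitude cols (W)
channels = 1                     # geopotential_500 is a single variable (C)
print([channels, lat_points, lon_points])  # [1, 32, 64]
```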
Run the following commands to download the data (compressed, ~1.4 GB) and extract it to the appropriate directory in the repository:

```shell
mkdir -p data/netcdf/
rsync -P rsync://m1524895@dataserv.ub.tum.de/m1524895/5.625deg/geopotential_500/geopotential_500_5.625deg.zip data/netcdf/geopotential_500_5.625deg.zip
# enter password: m1524895
unzip data/netcdf/geopotential_500_5.625deg.zip -d data/netcdf/geopotential_500_5.625deg
rm data/netcdf/geopotential_500_5.625deg.zip
```
Activate the environment:

```shell
conda activate mcrnn
```

Train a model (e.g., a 2-layer ConvLSTM with 4 cells each, on CPU):

```shell
python scripts/train.py model=convlstm model.hidden_sizes=[4,4] model.name=my_clstm4-4_model data=navier-stokes device="cpu"
```

Train on GPU with the default 16 cells:

```shell
python scripts/train.py model=convlstm model.name=my_clstm16-16_model data=navier-stokes device=cuda:0
```

More examples and all training commands used in this study are listed in the bash files.
Tip: use TensorBoard to inspect training:

```shell
tensorboard --logdir outputs
# Open http://localhost:6006/
```

To evaluate a trained model, run:
```shell
python scripts/evaluate.py -c outputs/my_clstm4-4_model
```

The correct model.name must be provided as the -c argument (short for "checkpoint"). The evaluation script computes the RMSE and animates the predicted dynamics next to the ground truth.
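For reference, the root-mean-square error over a predicted sequence follows the standard definition below; this is a generic sketch, not the repository's exact evaluation code.

```python
import numpy as np

def rmse(pred, target):
    """Root-mean-square error over all time steps and grid points."""
    pred, target = np.asarray(pred), np.asarray(target)
    return float(np.sqrt(np.mean((pred - target) ** 2)))

# Example on dummy [T, C, H, W] tensors:
pred = np.zeros((2, 1, 2, 2))
target = np.ones((2, 1, 2, 2))
print(rmse(pred, target))  # 1.0
```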
Multi- and cross-model evaluations can be performed by passing multiple model names, e.g.,

```shell
python scripts/evaluate.py -c outputs/my_clstm4-4_model outputs/my_clstm16-16_model
```

Wildcards can be used to indicate a family of models by name, e.g.,

```shell
python scripts/evaluate.py -c outputs/*clstm*
```

evaluates all models in the outputs directory that have clstm in their name.
If you find this repository helpful, please cite:
```bibtex
@InProceedings{horuz2025minimal,
  author    = "Horuz, Co{\c{s}}ku Can and Otte, Sebastian and Butz, Martin V. and Karlbauer, Matthias",
  title     = "Minimal Convolutional RNNs Accelerate Spatiotemporal Learning",
  booktitle = "Artificial Neural Networks and Machine Learning -- ICANN 2025",
  year      = "2025",
  publisher = "Springer Nature Switzerland",
  address   = "Cham",
  pages     = "545--557",
  isbn      = "978-3-032-04558-4"
}
```

This project is licensed under the MIT License.