Draft

27 commits
7c4321c
almost there
BernhardAhrens Dec 23, 2025
ce5d23c
up to compute_loss
BernhardAhrens Dec 23, 2025
35b91ca
align by name
BernhardAhrens Dec 25, 2025
49731c0
row vs col conflict
BernhardAhrens Dec 25, 2025
a3ffec0
works
BernhardAhrens Dec 26, 2025
196a677
make mlp and lstm run
BernhardAhrens Dec 26, 2025
bae8db3
col = inout (features, forcing, targets, output), row = batch_size, w…
BernhardAhrens Dec 26, 2025
6090122
toArray function uses inout
BernhardAhrens Dec 27, 2025
3c38f02
towards switch for DimArrays
BernhardAhrens Dec 27, 2025
63f4ecf
format runic
BernhardAhrens Dec 27, 2025
f3555be
DimensionalData cannot do backprop thru at view
BernhardAhrens Dec 27, 2025
56b4b2f
put output 3D 2D and DimArray and KeyedArray into Dataframe
BernhardAhrens Dec 27, 2025
c917a33
construction of broadcast layer as RecurrenceOutputDense
BernhardAhrens Dec 27, 2025
e0d1665
Dev example as Literate.jl script
BernhardAhrens Dec 27, 2025
baa7170
Update docs/literate/tutorials/example_synthetic_lstm.jl
BernhardAhrens Dec 27, 2025
455360f
prints
BernhardAhrens Dec 27, 2025
92792eb
print short
BernhardAhrens Dec 27, 2025
a0fb475
Merge branch 'ba/lstm' of https://github.com/EarthyScience/EasyHybrid…
BernhardAhrens Dec 27, 2025
2fd596d
update tutorials
BernhardAhrens Dec 28, 2025
740483c
KeyedArray default
BernhardAhrens Dec 28, 2025
d7f0c48
correct dispatch
BernhardAhrens Dec 28, 2025
20770ec
prepare_data function uses KeyedArray as default
BernhardAhrens Jan 16, 2026
1b3e203
fix tests
BernhardAhrens Jan 16, 2026
d62cdf8
correct matrices
BernhardAhrens Jan 19, 2026
4071a4f
run format.jl
BernhardAhrens Jan 19, 2026
7f16fae
get test running
BernhardAhrens Jan 19, 2026
aff68de
runic macro conflict
BernhardAhrens Jan 19, 2026
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "EasyHybrid"
uuid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"]
version = "0.1.7"
authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"]

[deps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
@@ -21,6 +21,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -57,6 +58,7 @@ MLJ = "0.20, 0.21, 0.22"
MLUtils = "0.4.8"
Makie = "0.22, 0.23, 0.24"
NCDatasets = "0.14.8"
NamedDims = "1.2.3"
OptimizationOptimisers = "0.3.7"
PrettyTables = "2.4.0, 3.1.2"
ProgressMeter = "1.10.4"
4 changes: 4 additions & 0 deletions docs/Project.toml
@@ -1,11 +1,15 @@
[deps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"

[sources]
191 changes: 191 additions & 0 deletions docs/literate/tutorials/example_synthetic_lstm.jl
@@ -0,0 +1,191 @@
# # LSTM Hybrid Model with EasyHybrid.jl
#
# This tutorial demonstrates how to use EasyHybrid to train a hybrid model that combines an LSTM
# neural network with a mechanistic respiration model (Q10 temperature sensitivity) on synthetic data.
# The code for this tutorial can be found in [docs/src/literate/tutorials](https://github.com/EarthyScience/EasyHybrid.jl/tree/main/docs/src/literate/tutorials/) as `example_synthetic_lstm.jl`.
#
# ## 1. Load Packages

# Set project path and activate environment
#using Pkg
#project_path = "docs"
#Pkg.activate(project_path)
#EasyHybrid_path = joinpath(pwd())
#Pkg.develop(path = EasyHybrid_path)
#Pkg.resolve()
#Pkg.instantiate()

using EasyHybrid
using AxisKeys
using DimensionalData
using Lux
using Random # needed for Random.default_rng() in Section 8

# ## 2. Data Loading and Preprocessing

# Load synthetic dataset from GitHub
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");

# Select a subset of data for faster execution
df = df[1:20000, :];
first(df, 5);

# ## 3. Define Neural Network Architectures

# Define a standard feedforward neural network
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))

# Define LSTM-based neural network with memory
# Note: When the Chain ends with a Recurrence layer, EasyHybrid automatically adds
# a RecurrenceOutputDense layer to handle the sequence output format.
# The user only needs to define the Recurrence layer itself!
NN_Memory = Chain(
Recurrence(LSTMCell(15 => 15), return_sequence = true),
)

# ## 4. Define the Physical Model

"""
RbQ10(; ta, Q10, rb, tref=15.0f0)

Respiration model with Q10 temperature sensitivity.

- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
return (; reco, Q10, rb)
end
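
# As a quick sanity check (illustrative values, not fitted ones): since `Q10` is the
# multiplicative factor per 10 °C of warming, `rb = 3` and `Q10 = 2` give `reco = rb` at the
# reference temperature and `2 * rb` at `tref + 10` °C.
RbQ10(ta = 15.0f0, Q10 = 2.0f0, rb = 3.0f0).reco # = 3.0f0
RbQ10(ta = 25.0f0, Q10 = 2.0f0, rb = 3.0f0).reco # = 6.0f0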

# ## 5. Define Model Parameters

# Parameter specification: (default, lower_bound, upper_bound)
parameters = (
rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s]
Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-]
)
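
# A common way to keep a learnable parameter within (lower, upper) bounds is a scaled
# sigmoid of an unconstrained value; a sketch of the idea (not necessarily EasyHybrid's
# exact transform):
bounded(raw, lower, upper) = lower + (upper - lower) / (1 + exp(-raw))
bounded(0.0f0, 1.0f0, 4.0f0) # = 2.5f0, the midpoint of the Q10 bounds at raw = 0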

# ## 6. Configure Hybrid Model Components

# Define input variables
# Forcing variables (air temperature)
forcing = [:ta]
# Predictor variables (potential shortwave radiation and its temporal derivative)
predictors = [:sw_pot, :dsw_pot]
# Target variable (respiration)
target = [:reco]

# Parameter classification
# Global parameters (same for all samples)
global_param_names = [:Q10]
# Neural network predicted parameters
neural_param_names = [:rb]

# ## 7. Construct LSTM Hybrid Model

# Create LSTM hybrid model using the unified constructor
hlstm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN_Memory, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false # Whether to apply batch normalization to inputs
)

# ## 8. Data Preparation Steps (Demonstration)

# The following steps demonstrate what happens under the hood during training.
# In practice, you can skip to Section 9 and use the `train` function directly.

# Both :KeyedArray and :DimArray are supported as `array_type`
x, y = prepare_data(hlstm, df, array_type = :DimArray);

# `split_into_sequences` takes `input_window`, `output_window`, `shift`, and `lead_time` keywords,
# supporting many-to-one and many-to-many setups with configurable prediction lead times and overlap
xs, ys = split_into_sequences(x, y; input_window = 20, output_window = 2, shift = 1, lead_time = 0);
ys_nan = .!isnan.(ys);
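
# To build intuition for the windowing (a hand-rolled sketch of the index arithmetic,
# not EasyHybrid's implementation): each sample pairs `input_window` consecutive input
# steps with the `output_window` target steps that start `lead_time` steps after the
# input window ends, and consecutive samples are offset by `shift` steps.
function sketch_window_indices(n; input_window, output_window, shift, lead_time)
    last_start = n - input_window - lead_time - output_window + 1
    return [
        (s:(s + input_window - 1), (s + input_window + lead_time):(s + input_window + lead_time + output_window - 1))
            for s in 1:shift:last_start
    ]
end
# With 25 time steps, the settings above yield 4 (input, target) index pairs,
# the first being (1:20, 21:22):
sketch_window_indices(25; input_window = 20, output_window = 2, shift = 1, lead_time = 0)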

# Split the data into training and validation sets, as `train` does internally
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = 10, output_window = 3, shift = 1, lead_time = 1));

typeof(sdf);
(x_train, y_train), (x_val, y_val) = sdf;
x_train;
y_train;
y_train_nan = .!isnan.(y_train);

# Wrap the training data in a `DataLoader` to compose minibatches
train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32);

# Run the hybrid model forward on the first minibatch
x_first = first(train_dl)[1];
y_first = first(train_dl)[2];

ps, st = Lux.setup(Random.default_rng(), hlstm);
frun = hlstm(x_first, ps, st);

# Extract the predicted respiration (yhat)
reco_mod = frun[1].reco;

# Bring the observations into the same shape as the predictions
reco_obs = dropdims(y_first, dims = 1);
reco_nan = .!isnan.(reco_obs);

# Compute loss
EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true));

# ## 9. Train LSTM Hybrid Model

out_lstm = train(
hlstm,
df,
();
nepochs = 2, # Number of training epochs
batchsize = 512, # Batch size for training
opt = AdamW(0.1), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = false,
loss_types = [:mse, :nse],
sequence_kwargs = (; input_window = 10, output_window = 4),
plotting = false,
array_type = :DimArray
);

# ## 10. Train Single NN Hybrid Model (Optional)

# For comparison, we can also train a hybrid model with a standard feedforward neural network
hm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false, # Whether to apply batch normalization to inputs
)

# Train the hybrid model
single_nn_out = train(
hm,
df,
();
nepochs = 3, # Number of training epochs
batchsize = 512, # Batch size for training
opt = AdamW(0.1), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = false,
loss_types = [:mse, :nse],
array_type = :DimArray
);
6 changes: 6 additions & 0 deletions docs/literate/tutorials/folds.jl
@@ -5,6 +5,12 @@
#
# ## 1. Load Packages

#using Pkg
#project_path = "docs"
#Pkg.activate(project_path)
#EasyHybrid_path = joinpath(pwd())
#Pkg.develop(path = EasyHybrid_path)

using EasyHybrid
using OhMyThreads
using CairoMakie
1 change: 1 addition & 0 deletions docs/make.jl
@@ -60,6 +60,7 @@ makedocs(;
"Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md",
"Slurm" => "tutorials/slurm.md",
"Cross-validation" => "tutorials/folds.md",
"LSTM Hybrid Model" => "tutorials/example_synthetic_lstm.md",
"Loss Functions" => "tutorials/losses.md",
],
"Research" => [