diff --git a/Project.toml b/Project.toml index ececaefa..51ae98ed 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EasyHybrid" uuid = "61bb816a-e6af-4913-ab9e-91bff2e122e3" -authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"] version = "0.1.7" +authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"] [deps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" @@ -21,6 +21,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" +NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -57,6 +58,7 @@ MLJ = "0.20, 0.21, 0.22" MLUtils = "0.4.8" Makie = "0.22, 0.23, 0.24" NCDatasets = "0.14.8" +NamedDims = "1.2.3" OptimizationOptimisers = "0.3.7" PrettyTables = "2.4.0, 3.1.2" ProgressMeter = "1.10.4" diff --git a/docs/Project.toml b/docs/Project.toml index 52df2908..d06cca83 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,11 +1,15 @@ [deps] +AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc" +DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3" +GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" [sources] diff --git a/docs/literate/tutorials/example_synthetic_lstm.jl b/docs/literate/tutorials/example_synthetic_lstm.jl new file mode 100644 index 00000000..314ca6ab --- /dev/null +++ b/docs/literate/tutorials/example_synthetic_lstm.jl @@ -0,0 +1,191 @@ +# # LSTM Hybrid Model with EasyHybrid.jl +# +# This tutorial demonstrates how to use EasyHybrid to train a hybrid model with LSTM +# neural networks on synthetic data for respiration modeling with Q10 temperature sensitivity. +# The code for this tutorial can be found in [docs/src/literate/tutorials](https://github.com/EarthyScience/EasyHybrid.jl/tree/main/docs/src/literate/tutorials/) => example_synthetic_lstm.jl. +# +# ## 1. Load Packages + +# Set project path and activate environment +#using Pkg +#project_path = "docs" +#Pkg.activate(project_path) +#EasyHybrid_path = joinpath(pwd()) +#Pkg.develop(path = EasyHybrid_path) +#Pkg.resolve() +#Pkg.instantiate() + +using EasyHybrid +using AxisKeys +using DimensionalData +using Lux + +# ## 2. Data Loading and Preprocessing + +# Load synthetic dataset from GitHub +df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc"); + +# Select a subset of data for faster execution +df = df[1:20000, :]; +first(df, 5); + +# ## 3. Define Neural Network Architectures + +# Define a standard feedforward neural network +NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1)) + +# Define LSTM-based neural network with memory +# Note: When the Chain ends with a Recurrence layer, EasyHybrid automatically adds +# a RecurrenceOutputDense layer to handle the sequence output format. +# The user only needs to define the Recurrence layer itself! 
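+#
+# For orientation, a sketch of what `prepare_hidden_chain` is assumed to build from the
+# chain below (dimensions inferred from the LSTMCell; the default hidden activation `tanh`
+# and generic `n_in`/`n_out` sizes are placeholders, not values fixed by this tutorial):
+#
+#   Chain(Dense(n_in, 15, tanh),
+#         Recurrence(LSTMCell(15 => 15), return_sequence = true),
+#         RecurrenceOutputDense(15 => 15),
+#         Dense(15, n_out))
+#
+# i.e. the per-timestep LSTM outputs are mapped back into a (features, time, batch) array.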
+NN_Memory = Chain( + Recurrence(LSTMCell(15 => 15), return_sequence = true), +) + +# ## 4. Define the Physical Model + +""" + RbQ10(; ta, Q10, rb, tref=15.0f0) + +Respiration model with Q10 temperature sensitivity. + +- `ta`: air temperature [°C] +- `Q10`: temperature sensitivity factor [-] +- `rb`: basal respiration rate [μmol/m²/s] +- `tref`: reference temperature [°C] (default: 15.0) +""" +function RbQ10(; ta, Q10, rb, tref = 15.0f0) + reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref)) + return (; reco, Q10, rb) +end + +# ## 5. Define Model Parameters + +# Parameter specification: (default, lower_bound, upper_bound) +parameters = ( + rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s] + Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-] +) + +# ## 6. Configure Hybrid Model Components + +# Define input variables +# Forcing variables (temperature) +forcing = [:ta] +# Predictor variables (solar radiation, and its derivative) +predictors = [:sw_pot, :dsw_pot] +# Target variable (respiration) +target = [:reco] + +# Parameter classification +# Global parameters (same for all samples) +global_param_names = [:Q10] +# Neural network predicted parameters +neural_param_names = [:rb] + +# ## 7. Construct LSTM Hybrid Model + +# Create LSTM hybrid model using the unified constructor +hlstm = constructHybridModel( + predictors, + forcing, + target, + RbQ10, + parameters, + neural_param_names, + global_param_names, + hidden_layers = NN_Memory, # Neural network architecture + scale_nn_outputs = true, # Scale neural network outputs + input_batchnorm = false # Apply batch normalization to inputs +) + +# ## 8. Data Preparation Steps (Demonstration) + +# The following steps demonstrate what happens under the hood during training. +# In practice, you can skip to Section 9 and use the `train` function directly. + +# :KeyedArray and :DimArray are supported +x, y = prepare_data(hlstm, df, array_type = :DimArray); + +# New split_into_sequences with input_window, output_window, shift and lead_time +# for many-to-one, many-to-many, and different prediction lead times and overlap +xs, ys = split_into_sequences(x, y; input_window = 20, output_window = 2, shift = 1, lead_time = 0); +ys_nan = .!isnan.(ys); + +# Split data as in train +sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = 10, output_window = 3, shift = 1, lead_time = 1)); + +typeof(sdf); +(x_train, y_train), (x_val, y_val) = sdf; +x_train; +y_train; +y_train_nan = .!isnan.(y_train); + +# Put into train loader to compose minibatches +train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32); + +# Run hybrid model forwards +x_first = first(train_dl)[1]; +y_first = first(train_dl)[2]; + +ps, st = Lux.setup(Random.default_rng(), hlstm); +frun = hlstm(x_first, ps, st); + +# Extract predicted yhat +reco_mod = frun[1].reco; + +# Bring observations in same shape +reco_obs = dropdims(y_first, dims = 1); +reco_nan = .!isnan.(reco_obs); + +# Compute loss +EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true)); + +# ## 9. 
Train LSTM Hybrid Model + +out_lstm = train( + hlstm, + df, + (); + nepochs = 2, # Number of training epochs + batchsize = 512, # Batch size for training + opt = AdamW(0.1), # Optimizer and learning rate + monitor_names = [:rb, :Q10], # Parameters to monitor during training + yscale = identity, # Scaling for outputs + shuffleobs = false, + loss_types = [:mse, :nse], + sequence_kwargs = (; input_window = 10, output_window = 4), + plotting = false, + array_type = :DimArray +); + +# ## 10. Train Single NN Hybrid Model (Optional) + +# For comparison, we can also train a hybrid model with a standard feedforward neural network +hm = constructHybridModel( + predictors, + forcing, + target, + RbQ10, + parameters, + neural_param_names, + global_param_names, + hidden_layers = NN, # Neural network architecture + scale_nn_outputs = true, # Scale neural network outputs + input_batchnorm = false, # Apply batch normalization to inputs +) + +# Train the hybrid model +single_nn_out = train( + hm, + df, + (); + nepochs = 3, # Number of training epochs + batchsize = 512, # Batch size for training + opt = AdamW(0.1), # Optimizer and learning rate + monitor_names = [:rb, :Q10], # Parameters to monitor during training + yscale = identity, # Scaling for outputs + shuffleobs = false, + loss_types = [:mse, :nse], + array_type = :DimArray +); diff --git a/docs/literate/tutorials/folds.jl b/docs/literate/tutorials/folds.jl index 47933d59..21d0dbe5 100644 --- a/docs/literate/tutorials/folds.jl +++ b/docs/literate/tutorials/folds.jl @@ -5,6 +5,12 @@ # # ## 1. Load Packages +#using Pkg +#project_path = "docs" +#Pkg.activate(project_path) +#EasyHybrid_path = joinpath(pwd()) +#Pkg.develop(path = EasyHybrid_path) + using EasyHybrid using OhMyThreads using CairoMakie diff --git a/docs/make.jl b/docs/make.jl index d7bd3f01..cecb28ba 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -60,6 +60,7 @@ makedocs(; "Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md", "Slurm" => "tutorials/slurm.md", "Cross-validation" => "tutorials/folds.md", + "LSTM Hybrid Model" => "tutorials/example_synthetic_lstm.md", "Loss Functions" => "tutorials/losses.md", ], "Research" => [ diff --git a/docs/src/tutorials/example_synthetic_lstm.md b/docs/src/tutorials/example_synthetic_lstm.md new file mode 100644 index 00000000..9acb26c2 --- /dev/null +++ b/docs/src/tutorials/example_synthetic_lstm.md @@ -0,0 +1,281 @@ +```@meta +EditURL = "../../literate/tutorials/example_synthetic_lstm.jl" +``` + +# LSTM Hybrid Model with EasyHybrid.jl + +This tutorial demonstrates how to use EasyHybrid to train a hybrid model with LSTM +neural networks on synthetic data for respiration modeling with Q10 temperature sensitivity. +The code for this tutorial can be found in [docs/src/literate/tutorials](https://github.com/EarthyScience/EasyHybrid.jl/tree/main/docs/src/literate/tutorials/) => example_synthetic_lstm.jl. + +## 1. Load Packages + +Set project path and activate environment + +````@example example_synthetic_lstm +using Pkg +project_path = "docs" +Pkg.activate(project_path) +EasyHybrid_path = joinpath(pwd()) +Pkg.develop(path = EasyHybrid_path) +Pkg.resolve() +Pkg.instantiate() + +using EasyHybrid +using AxisKeys +using DimensionalData +using Lux +```` + +## 2. 
Data Loading and Preprocessing + +Load synthetic dataset from GitHub + +````@example example_synthetic_lstm +df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc"); +nothing #hide +```` + +Select a subset of data for faster execution + +````@example example_synthetic_lstm +df = df[1:20000, :]; +first(df, 5); +nothing #hide +```` + +## 3. Define Neural Network Architectures + +Define a standard feedforward neural network + +````@example example_synthetic_lstm +NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1)) +```` + +Define LSTM-based neural network with memory +Note: When the Chain ends with a Recurrence layer, EasyHybrid automatically adds +a RecurrenceOutputDense layer to handle the sequence output format. +The user only needs to define the Recurrence layer itself! + +````@example example_synthetic_lstm +NN_Memory = Chain( + Recurrence(LSTMCell(15 => 15), return_sequence = true), +) +```` + +## 4. Define the Physical Model + +````@example example_synthetic_lstm +""" + RbQ10(; ta, Q10, rb, tref=15.0f0) + +Respiration model with Q10 temperature sensitivity. + +- `ta`: air temperature [°C] +- `Q10`: temperature sensitivity factor [-] +- `rb`: basal respiration rate [μmol/m²/s] +- `tref`: reference temperature [°C] (default: 15.0) +""" +function RbQ10(; ta, Q10, rb, tref = 15.0f0) + reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref)) + return (; reco, Q10, rb) +end +```` + +## 5. Define Model Parameters + +Parameter specification: (default, lower_bound, upper_bound) + +````@example example_synthetic_lstm +parameters = ( + rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s] + Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-] +) +```` + +## 6. Configure Hybrid Model Components + +Define input variables +Forcing variables (temperature) + +````@example example_synthetic_lstm +forcing = [:ta] +```` + +Predictor variables (solar radiation, and its derivative) + +````@example example_synthetic_lstm +predictors = [:sw_pot, :dsw_pot] +```` + +Target variable (respiration) + +````@example example_synthetic_lstm +target = [:reco] +```` + +Parameter classification +Global parameters (same for all samples) + +````@example example_synthetic_lstm +global_param_names = [:Q10] +```` + +Neural network predicted parameters + +````@example example_synthetic_lstm +neural_param_names = [:rb] +```` + +## 7. Construct LSTM Hybrid Model + +Create LSTM hybrid model using the unified constructor + +````@example example_synthetic_lstm +hlstm = constructHybridModel( + predictors, + forcing, + target, + RbQ10, + parameters, + neural_param_names, + global_param_names, + hidden_layers = NN_Memory, # Neural network architecture + scale_nn_outputs = true, # Scale neural network outputs + input_batchnorm = false # Apply batch normalization to inputs +) +```` + +## 8. Data Preparation Steps (Demonstration) + +The following steps demonstrate what happens under the hood during training. +In practice, you can skip to Section 9 and use the `train` function directly. 
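Before the EasyHybrid-specific calls, here is a standalone sketch of the windowing idea itself
(plain Julia, not the package implementation; the sizes are illustrative): a `(features, time)`
matrix is cut into overlapping windows and stacked into a `(features, input_window, n_samples)`
array, which is the layout the recurrent network expects.

````julia
x = rand(Float32, 3, 100)                                    # 3 features, 100 time steps
input_window, shift = 20, 1
starts = 1:shift:(size(x, 2) - input_window + 1)
xs = stack(x[:, s:(s + input_window - 1)] for s in starts)   # size (3, 20, 81)
````
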
+ +:KeyedArray and :DimArray are supported + +````@example example_synthetic_lstm +x, y = prepare_data(hlstm, df, array_type = :DimArray); +nothing #hide +```` + +New split_into_sequences with input_window, output_window, shift and lead_time +for many-to-one, many-to-many, and different prediction lead times and overlap + +````@example example_synthetic_lstm +xs, ys = split_into_sequences(x, y; input_window = 20, output_window = 2, shift = 1, lead_time = 0); +ys_nan = .!isnan.(ys); +nothing #hide +```` + +Split data as in train + +````@example example_synthetic_lstm +sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = 10, output_window = 3, shift = 1, lead_time = 1)); + +typeof(sdf); +(x_train, y_train), (x_val, y_val) = sdf; +x_train; +y_train; +y_train_nan = .!isnan.(y_train); +nothing #hide +```` + +Put into train loader to compose minibatches + +````@example example_synthetic_lstm +train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32); +nothing #hide +```` + +Run hybrid model forwards + +````@example example_synthetic_lstm +x_first = first(train_dl)[1]; +y_first = first(train_dl)[2]; + +ps, st = Lux.setup(Random.default_rng(), hlstm); +frun = hlstm(x_first, ps, st); +nothing #hide +```` + +Extract predicted yhat + +````@example example_synthetic_lstm +reco_mod = frun[1].reco; +nothing #hide +```` + +Bring observations in same shape + +````@example example_synthetic_lstm +reco_obs = dropdims(y_first, dims = 1); +reco_nan = .!isnan.(reco_obs); +nothing #hide +```` + +Compute loss + +````@example example_synthetic_lstm +EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true)); +nothing #hide +```` + +## 9. Train LSTM Hybrid Model + +````@example example_synthetic_lstm +out_lstm = train( + hlstm, + df, + (); + nepochs = 2, # Number of training epochs + batchsize = 512, # Batch size for training + opt = AdamW(0.1), # Optimizer and learning rate + monitor_names = [:rb, :Q10], # Parameters to monitor during training + yscale = identity, # Scaling for outputs + shuffleobs = false, + loss_types = [:mse, :nse], + sequence_kwargs = (; input_window = 10, output_window = 4), + plotting = false, + array_type = :DimArray +); +nothing #hide +```` + +## 10. 
Train Single NN Hybrid Model (Optional) + +For comparison, we can also train a hybrid model with a standard feedforward neural network + +````@example example_synthetic_lstm +hm = constructHybridModel( + predictors, + forcing, + target, + RbQ10, + parameters, + neural_param_names, + global_param_names, + hidden_layers = NN, # Neural network architecture + scale_nn_outputs = true, # Scale neural network outputs + input_batchnorm = false, # Apply batch normalization to inputs +) +```` + +Train the hybrid model + +````@example example_synthetic_lstm +single_nn_out = train( + hm, + df, + (); + nepochs = 3, # Number of training epochs + batchsize = 512, # Batch size for training + opt = AdamW(0.1), # Optimizer and learning rate + monitor_names = [:rb, :Q10], # Parameters to monitor during training + yscale = identity, # Scaling for outputs + shuffleobs = false, + loss_types = [:mse, :nse], + array_type = :DimArray +); +nothing #hide +```` + diff --git a/projects/RbQ10/Q10_lstm.jl b/projects/RbQ10/Q10_lstm.jl new file mode 100644 index 00000000..7a0e04f7 --- /dev/null +++ b/projects/RbQ10/Q10_lstm.jl @@ -0,0 +1,138 @@ +# activate the project's environment and instantiate dependencies +using Pkg +Pkg.activate("projects/RbQ10") +Pkg.develop(path = pwd()) +Pkg.instantiate() + +# start using the package +using EasyHybrid +using EasyHybrid.MLUtils +using Random +# for Plotting +using GLMakie +using AlgebraOfGraphics +using EasyHybrid.AxisKeys +function split_into_sequences(xin, y_target; window_size = 8) + features = size(xin, 1) + xdata = slidingwindow(xin, size = window_size, stride = 1) + # get the target values corresponding to the sliding windows, + # elements of `ydata` correspond to the last sliding window element. + #ydata = y_target[window_size:length(xdata) + window_size - 1] + ydata = slidingwindow(y_target, size = window_size, stride = 1) + + xwindowed = zeros(Float32, features, window_size, length(ydata)) + #ywindowed = zeros(Float32, 1, 1, length(ydata)) + ywindowed = zeros(Float32, 1, window_size, length(ydata)) + for i in 1:length(ydata) + xwindowed[:, :, i] = getobs(xdata, i) + ywindowed[:, :, i] = getobs(ydata, i) + end + xwindowed = KeyedArray(xwindowed; row = xin.row, window = 1:window_size, col = 1:length(ydata)) + ywindowed = KeyedArray(ywindowed; row = 1:1, window = 1:window_size, col = 1:length(ydata)) + return xwindowed, ywindowed +end + +xwindowed, ydata = split_into_sequences(ds_p_f, ds_t; window_size = 8) + +script_dir = @__DIR__ +include(joinpath(script_dir, "data", "prec_process_data.jl")) + +# Common data preprocessing +df = dfall[!, Not(:timesteps)] +ds_keyed = to_keyedArray(Float32.(df)) + +target_names = [:R_soil] +forcing_names = [:cham_temp_filled] +predictor_names = [:moisture_filled, :rgpot2] + +# Define neural network +NN = Chain(Dense(2, 15, Lux.relu), Dense(15, 15, Lux.relu), Dense(15, 1)) +# NN(rand(Float32, 2,1)) #! needs to be instantiated + + +# instantiate Hybrid Model +RbQ10 = RespirationRbQ10(NN, predictor_names, target_names, forcing_names, 2.5f0) # ? do different initial Q10s +# train model +out = train(RbQ10, ds_keyed, (:Q10,); nepochs = 10, batchsize = 512, opt = Adam(0.01)); + +## +output_file = joinpath(@__DIR__, "output_tmp/trained_model.jld2") +all_groups = get_all_groups(output_file) + +# ? Let's use `Recurrence` to stack LSTM cells and deal with sequences and batching at the same time! 
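+# (Assumed shapes for the stacked cells below: the first Recurrence with
+#  return_sequence = true emits one hidden state per time step, the second with
+#  return_sequence = false keeps only the last step, so Dense(2 => 1) receives a
+#  (2, batch) array, one prediction per window.)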
+ +NN_Memory = Chain( + Recurrence(LSTMCell(2 => 6), return_sequence = true), + Recurrence(LSTMCell(6 => 2), return_sequence = false), + Dense(2 => 1) +) + +# c_lstm = LSTMCell(4 => 6) +# ps, st = Lux.setup(rng, c_lstm) +# (y, carry), st_lstm = c_lstm(rand(Float32, 4), ps, st) +# m_lstm = Recurrence(LSTMCell(4 => 6), return_sequence=false) + +rng = Random.default_rng(1234) +ps, st = Lux.setup(rng, NN_Memory) +mock_data = rand(Float32, 2, 8, 5) #! n_features, n_timesteps (window size), n_samples (batch size) +y_, st_ = NN_Memory(mock_data, ps, st) + +RbQ10_Memory = RespirationRbQ10(NN_Memory, predictor_names, target_names, forcing_names, 2.5f0) # ? do different initial Q10s + +## legacy +# ? test lossfn +ps, st = LuxCore.setup(Random.default_rng(), RbQ10_Memory) +# the Tuple `ds_p, ds_t` is later used for batching in the `dataloader`. +ds_p_f, ds_t = EasyHybrid.prepare_data(RbQ10_Memory, ds_keyed) +ds_t_nan = .!isnan.(ds_t) + +# +xwindowed, ydata = split_into_sequences(ds_p_f, ds_t; window_size = 8) +ds_wt = ydata +ds_wt_nan = .!isnan.(ds_wt) +# xdata = slidingwindow(ds_p_f, size = 8, stride = 1) +# getobs(xdata, 1) +# split_test = rand(Float32, 3, 8, 5) +using EasyHybrid.AxisKeys +# A = KeyedArray(split_test; row=ds_p_f.row, window=1:8, col=1:5) + +broadcast_layer2 = @compact(; layer = Dense(2 => 1)) do x::Union{NTuple{<:AbstractArray}, AbstractVector{<:AbstractArray}} + y = map(layer, x) + @return permutedims(stack(y; dims = 3), (1, 3, 2)) +end + +NN_Memory = Chain( + Recurrence(LSTMCell(2 => 2), return_sequence = true), + broadcast_layer2 +) + + +RbQ10_Memory = RespirationRbQ10(NN_Memory, predictor_names, target_names, forcing_names, 2.5f0) # ? do different initial Q10s + +#? this sets up initial ps for the hybrid model version +rng = Random.default_rng(1234) +ps, st = Lux.setup(rng, RbQ10_Memory) + +xdl = EasyHybrid.DataLoader(xwindowed; batchsize = 512) +ydl = EasyHybrid.DataLoader((ds_wt, ds_wt_nan); batchsize = 512) + +sdf = RbQ10_Memory(first(xdl), ps, st) + +mid1 = first(xdl)(RbQ10_Memory.forcing) +mid1[:, end, :] + + +o1, st1 = LuxCore.apply(RbQ10_Memory.NN, first(xdl)(RbQ10_Memory.predictors), ps.ps, st.st) + +sdf[1].R_soil +sdf[1].Rb +first(ydl)[1] + +ls = EasyHybrid.lossfn(RbQ10_Memory, first(xdl), first(ydl), ps, st, LoggingLoss()) + +ls_logs = EasyHybrid.lossfn(RbQ10_Memory, xwindowed, (ds_wt, ds_wt_nan), ps, st, LoggingLoss(train_mode = false)) + + +# p = A(RbQ10_Memory.predictors) +# x = Array(A(RbQ10_Memory.forcing)) # don't propagate names after this +# Rb, st = LuxCore.apply(RbQ10_Memory.NN, p, ps.ps, st.st) diff --git a/projects/book_chapter/Project.toml b/projects/book_chapter/Project.toml index 2bccc509..34cb44cb 100644 --- a/projects/book_chapter/Project.toml +++ b/projects/book_chapter/Project.toml @@ -3,6 +3,7 @@ AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3" +GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" diff --git a/projects/book_chapter/example_synthetic.jl b/projects/book_chapter/example_synthetic.jl index 642b2c88..c0bf3457 100644 --- a/projects/book_chapter/example_synthetic.jl +++ b/projects/book_chapter/example_synthetic.jl @@ -39,7 +39,7 @@ ka = to_keyedArray(dfnot) # DimensionalData mat = Array(Matrix(dfnot)') -da = DimArray(mat, 
(Dim{:col}(Symbol.(names(dfnot))), Dim{:row}(1:size(dfnot, 1)))) +da = DimArray(mat, (inout = Symbol.(names(dfnot)), batch_size = 1:size(dfnot, 1))) # ============================================================================= # Define the Physical Model @@ -81,7 +81,7 @@ neural_param_names = [:rb] # Neural network predicted parameters # ============================================================================= # Single NN Hybrid Model Training # ============================================================================= -using WGLMakie +using GLMakie # Create single NN hybrid model using the unified constructor predictors_single_nn = [:sw_pot, :dsw_pot] # Predictor variables (solar radiation, and its derivative) @@ -119,7 +119,9 @@ single_nn_out = train( yscale = identity, # Scaling for outputs shuffleobs = true, loss_types = [:mse, :nse], - extra_loss = extra_loss + extra_loss = extra_loss, + array_type = :KeyedArray, + plotting = false ) # ============================================================================= @@ -149,7 +151,9 @@ single_nn_out = train( opt = AdamW(0.1), # Optimizer and learning rate monitor_names = [:rb, :Q10], # Parameters to monitor during training yscale = identity, # Scaling for outputs - shuffleobs = true + shuffleobs = true, + array_type = :DimArray, + plotting = false ) LuxCore.testmode(single_nn_out.st) diff --git a/src/EasyHybrid.jl b/src/EasyHybrid.jl index 8fb3ed43..0f76e195 100644 --- a/src/EasyHybrid.jl +++ b/src/EasyHybrid.jl @@ -5,14 +5,18 @@ EasyHybrid is a Julia package for hybrid machine learning models, combining neur """ module EasyHybrid -using AxisKeys: AxisKeys, KeyedArray, axiskeys, wrapdims +using AxisKeys: AxisKeys, KeyedArray, Key, axiskeys, wrapdims using CSV: CSV using Chain: @chain using ChainRulesCore: ChainRulesCore using ComponentArrays: ComponentArrays, ComponentArray using DataFrameMacros: DataFrameMacros, @transform using DataFrames: DataFrames, DataFrame, GroupedDataFrame, Missing, coalesce, mapcols, select, missing, All -using DimensionalData: DimensionalData, AbstractDimArray, At, dims, groupby +using DimensionalData: DimensionalData, AbstractDimArray, Dim, DimArray, dims, groupby, lookup, At +# Extend axiskeys to work with DimArrays (delegates to lookup) +AxisKeys.axiskeys(da::AbstractDimArray) = Tuple(lookup(da, d) for d in dims(da)) +AxisKeys.axiskeys(da::AbstractDimArray, i::Int) = lookup(da, dims(da)[i]) +AxisKeys.axiskeys(da::AbstractDimArray, name::Symbol) = lookup(da, name) using Downloads: Downloads using Hyperopt: Hyperopt, Hyperoptimizer using JLD2: JLD2, jldopen @@ -39,6 +43,7 @@ using Static: False, True using Statistics using DataFrames using CSV + using AxisKeys: axiskeys using OptimizationOptimisers: OptimizationOptimisers, Optimisers, Adam, AdamW, RMSProp using ComponentArrays end diff --git a/src/models/GenericHybridModel.jl b/src/models/GenericHybridModel.jl index 5c74aa96..2587d208 100644 --- a/src/models/GenericHybridModel.jl +++ b/src/models/GenericHybridModel.jl @@ -380,7 +380,7 @@ function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, # 3) scale NN parameters (handle empty case) if !isempty(m.neural_param_names) nn_out, st_nn = LuxCore.apply(m.NN, predictors, ps.ps, st.st_nn) - nn_cols = eachrow(nn_out) + nn_cols = eachslice(nn_out, dims = 1) nn_params = NamedTuple(zip(m.neural_param_names, nn_cols)) # Use appropriate scaling based on setting @@ -533,7 +533,6 @@ function (m::MultiNNHybridModel)(df::DataFrame, ps, st) all_data = to_keyedArray(df) x, _ = prepare_data(m, 
all_data) - @show typeof(x) out, _ = m(x, ps, LuxCore.testmode(st)) dfnew = copy(df) for k in keys(out) diff --git a/src/models/NNModels.jl b/src/models/NNModels.jl index fddfec29..83ea9fda 100644 --- a/src/models/NNModels.jl +++ b/src/models/NNModels.jl @@ -1,4 +1,4 @@ -export SingleNNModel, MultiNNModel, constructNNModel, prepare_hidden_chain +export SingleNNModel, MultiNNModel, constructNNModel, prepare_hidden_chain, BroadcastLayer, RecurrenceOutputDense using Lux, LuxCore using ..EasyHybrid: hard_sigmoid @@ -16,6 +16,66 @@ struct SingleNNModel <: LuxCore.AbstractLuxContainerLayer{ scale_nn_outputs::Bool end +""" + RecurrenceOutputDense(in_dims => out_dims, [activation]) + +A layer that wraps a Dense layer to handle sequence outputs from Recurrence layers. + +When a Recurrence layer has `return_sequence=true`, it outputs a tuple/vector of arrays +(one per timestep). This layer broadcasts the Dense operation over each timestep and +reshapes the result to `(features, timesteps, batch)` format. + +# Arguments +- `in_dims::Int`: Input dimension (should match Recurrence output dimension) +- `out_dims::Int`: Output dimension +- `activation`: Activation function (default: `identity`) + +# Example +```julia +# Instead of manually creating: +broadcast_layer = @compact(; layer = Dense(15 => 15)) do x + y = map(layer, x) + @return permutedims(stack(y; dims = 3), (1, 3, 2)) +end + +# Simply use: +Chain( + Recurrence(LSTMCell(15 => 15), return_sequence = true), + RecurrenceOutputDense(15 => 15) +) +``` +""" +struct RecurrenceOutputDense{D <: Dense} <: LuxCore.AbstractLuxWrapperLayer{:layer} + layer::D +end + +function RecurrenceOutputDense(mapping::Pair{Int, Int}, activation = identity) + return RecurrenceOutputDense(Dense(mapping.first, mapping.second, activation)) +end + +function RecurrenceOutputDense(in_dims::Int, out_dims::Int, activation = identity) + return RecurrenceOutputDense(Dense(in_dims, out_dims, activation)) +end + +# Handle tuple output from Recurrence (return_sequence = true) +function (m::RecurrenceOutputDense)(x::NTuple{N, <:AbstractArray}, ps, st) where {N} + y = map(xi -> first(LuxCore.apply(m.layer, xi, ps, st)), x) + result = permutedims(stack(y; dims = 3), (1, 3, 2)) + return result, st +end + +# Handle vector output from Recurrence (return_sequence = true) +function (m::RecurrenceOutputDense)(x::AbstractVector{<:AbstractArray}, ps, st) + y = map(xi -> first(LuxCore.apply(m.layer, xi, ps, st)), x) + result = permutedims(stack(y; dims = 3), (1, 3, 2)) + return result, st +end + +# Fallback for regular array input (non-sequence mode) +function (m::RecurrenceOutputDense)(x::AbstractArray, ps, st) + return LuxCore.apply(m.layer, x, ps, st) +end + """ prepare_hidden_chain(hidden_layers, in_dim, out_dim; activation, input_batchnorm=false) @@ -26,6 +86,8 @@ Construct a neural network `Chain` for use in NN models. - If a `Vector{Int}`, specifies the sizes of each hidden layer. For example, `[32, 16]` creates two hidden layers with 32 and 16 units, respectively. - If a `Chain`, the user provides a pre-built chain of hidden layers (excluding input/output layers). + If the chain ends with a `Recurrence` layer, a `RecurrenceOutputDense` layer is automatically + added to handle the sequence output format. - `in_dim::Int`: Number of input features (input dimension). - `out_dim::Int`: Number of output features (output dimension). - `activation`: Activation function to use in hidden layers (default: `tanh`). @@ -36,9 +98,21 @@ Construct a neural network `Chain` for use in NN models. 
- Optional input batch normalization (if `input_batchnorm=true`) - Input layer: `Dense(in_dim, h₁, activation)` where `h₁` is the first hidden size - Hidden layers: either user-supplied `Chain` or constructed from `hidden_layers` + - If last hidden layer is a `Recurrence`, a `RecurrenceOutputDense` is added to handle sequence output - Output layer: `Dense(hₖ, out_dim)` where `hₖ` is the last hidden size where `h₁` is the first hidden size and `hₖ` the last. + +# Example with Recurrence (LSTM) +```julia +# User only needs to define: +NN_Memory = Chain( + Recurrence(LSTMCell(15 => 15), return_sequence = true), +) + +# The function automatically adds the RecurrenceOutputDense layer to handle sequence output +model = constructHybridModel(..., hidden_layers = NN_Memory, ...) +``` """ function prepare_hidden_chain( hidden_layers::Union{Vector{Int}, Chain}, @@ -49,26 +123,81 @@ function prepare_hidden_chain( ) if hidden_layers isa Chain # user gave a chain of hidden layers only - first_h = hidden_layers[1].out_dims - last_h = hidden_layers[end].out_dims - return Chain( - input_batchnorm ? BatchNorm(in_dim, affine = false) : identity, - Dense(in_dim, first_h, activation), - hidden_layers.layers..., - Dense(last_h, out_dim) - ) + # Helper to safely extract dimensions from layers + function get_layer_dim(l, type) + if type == :input + hasproperty(l, :in_dims) && return l.in_dims + (l isa BatchNorm && hasproperty(l, :dims)) && return l.dims + (l isa Recurrence && hasproperty(l.cell, :in_dims)) && return l.cell.in_dims + (l isa CompactLuxLayer && hasproperty(l.layers, :in_dims)) && return l.layers.in_dims + elseif type == :output + hasproperty(l, :out_dims) && return l.out_dims + (l isa BatchNorm && hasproperty(l, :dims)) && return l.dims + (l isa Recurrence && hasproperty(l.cell, :out_dims)) && return l.cell.out_dims + (l isa CompactLuxLayer && hasproperty(l.layers, :out_dims)) && return l.layers.out_dims + end + return nothing + end + + # Check if last layer is a Recurrence layer (needs special handling for sequence output) + # In this framework, we ALWAYS assume return_sequence=true for Recurrence layers + # (this is the EasyHybrid convention, regardless of Lux's default) + function is_sequence_recurrence(layer) + return layer isa Recurrence + end + + last_layer = hidden_layers.layers[end] + ends_with_sequence_recurrence = is_sequence_recurrence(last_layer) + + # Determine first_h by searching forward + first_h = nothing + for i in 1:length(hidden_layers) + d = get_layer_dim(hidden_layers[i], :input) + if !isnothing(d) + first_h = d + break + end + end + isnothing(first_h) && error("Could not determine input dimension of hidden_layers Chain.") + + # Determine last_h by searching backward + last_h = nothing + for i in length(hidden_layers):-1:1 + d = get_layer_dim(hidden_layers[i], :output) + if !isnothing(d) + last_h = d + break + end + end + isnothing(last_h) && error("Could not determine output dimension of hidden_layers Chain.") + + if ends_with_sequence_recurrence + # Chain ends with Recurrence layer (return_sequence=true) - add RecurrenceOutputDense to handle sequence output + return Chain( + input_batchnorm ? BatchNorm(in_dim, affine = false) : identity, + Dense(in_dim, first_h, activation), + hidden_layers.layers..., + RecurrenceOutputDense(last_h => last_h, activation), + Dense(last_h, out_dim) + ) + else + return Chain( + input_batchnorm ? 
BatchNorm(in_dim, affine = false) : identity, + Dense(in_dim, first_h, activation), + hidden_layers.layers..., + Dense(last_h, out_dim) + ) + end else # user gave a vector of hidden‐layer sizes hs = hidden_layers - # build the hidden‐to‐hidden part - hidden_chain = length(hs) > 1 ? - Chain((Dense(hs[i], hs[i + 1], activation) for i in 1:(length(hs) - 1))...) : - Chain() + isempty(hs) && return Chain() + in_dim == 0 && return Chain() return Chain( input_batchnorm ? BatchNorm(in_dim, affine = false) : identity, Dense(in_dim, hs[1], activation), - hidden_chain.layers..., + (Dense(hs[i], hs[i + 1], activation) for i in 1:(length(hs) - 1))..., Dense(hs[end], out_dim) ) end @@ -236,3 +365,26 @@ function Base.show(io::IO, ::MIME"text/plain", m::MultiNNModel) end return println("scale NN outputs: ", m.scale_nn_outputs) end + +struct BroadcastLayer{T <: NamedTuple} <: LuxCore.AbstractLuxContainerLayer{(:layers,)} + layers::T +end + +function BroadcastLayer(layers...) + for l in layers + if !iszero(LuxCore.statelength(l)) + throw(ArgumentError("Stateful layer `$l` are not supported for `BroadcastLayer`.")) + end + end + names = ntuple(i -> Symbol("layer_$i"), length(layers)) + return BroadcastLayer(NamedTuple{names}(layers)) +end + +BroadcastLayer(; kwargs...) = BroadcastLayer(connection, (; kwargs...)) + +function (m::BroadcastLayer)(x, ps, st::NamedTuple{names}) where {names} + results = (first ∘ Lux.apply).(values(m.layers), x, values(ps), values(st)) + return results, st +end + +Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers)) diff --git a/src/models/Respiration_Rb_Q10.jl b/src/models/Respiration_Rb_Q10.jl index a0cf0d5a..0f211130 100644 --- a/src/models/Respiration_Rb_Q10.jl +++ b/src/models/Respiration_Rb_Q10.jl @@ -60,8 +60,8 @@ function (hm::RespirationRbQ10)(ds_k, ps, st::NamedTuple) end function (hm::RespirationRbQ10)(ds_k::AbstractDimArray, ps, st::NamedTuple) - p = ds_k[col = At(hm.predictors)] - x = Array(ds_k[col = At(hm.forcing)]) # don't propagate names after this + p = ds_k[inout = At(hm.predictors)] # No @view - needs to be differentiable + x = Array(ds_k[inout = At(hm.forcing)]) # No @view - needs to be differentiable Rb, stQ10 = LuxCore.apply(hm.NN, p, ps.ps, st.st) #! NN(αᵢ(t)) ≡ Rb(T(t), M(t)) diff --git a/src/plotrecipes.jl b/src/plotrecipes.jl index 8804e50a..7927e8f3 100644 --- a/src/plotrecipes.jl +++ b/src/plotrecipes.jl @@ -114,11 +114,11 @@ function to_obs_tuple(y, target_names) end function to_tuple(y::KeyedArray, target_names) - return (; (t => y(t) for t in target_names)...) # observations are fixed, no Observables are needed! + return (; (t => y(inout = t) for t in target_names)...) # observations are fixed, no Observables are needed! end function to_tuple(y::AbstractDimArray, target_names) - return (; (t => Array(y[col = At(t)]) for t in target_names)...) # observations are fixed, no Observables are needed! + return (; (t => Array(y[inout = At(t)]) for t in target_names)...) # observations are fixed, no Observables are needed! end function monitor_to_obs(ŷ, monitor_names; cuts = (0.25, 0.5, 0.75)) diff --git a/src/train.jl b/src/train.jl index 282f389c..b37550c0 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,4 +1,4 @@ -export train, TrainResults, prepare_data, split_data +export train, TrainResults, prepare_data, split_data, split_into_sequences, windowed_dataset # beneficial for plotting based on type TrainResults? 
struct TrainResults train_history @@ -39,7 +39,8 @@ Default output file is `trained_model.jld2` at the current working directory und - `loss_types`: A vector of loss types to compute during training (default: `[:mse, :r2]`). The first entry is used for plotting in the dynamic trainboard. This loss can be increasing (e.g. NSE) or decreasing (e.g. RMSE). - `agg`: The aggregation function to apply to the computed losses (default: `sum`). -## Data Handling (passed via kwargs): +## Data Handling: +- `array_type`: Array type for data conversion from DataFrame: `:DimArray` (default) or `:KeyedArray`. - `shuffleobs`: Whether to shuffle the training data (default: false). - `split_by_id`: Column name or function to split data by ID (default: nothing -> no ID-based splitting). - `split_data_at`: Fraction of data to use for training when splitting (default: 0.8). @@ -71,6 +72,8 @@ function train( patience = typemax(Int), autodiff_backend = AutoZygote(), return_gradients = True(), + # Array type for data conversion + array_type = :KeyedArray, # :DimArray or :KeyedArray # Loss and evaluation training_loss = :mse, loss_types = [:mse, :r2], @@ -116,7 +119,7 @@ function train( Random.seed!(random_seed) end - (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; kwargs...) + (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; array_type = array_type, kwargs...) train_loader = DataLoader((x_train, y_train), batchsize = batchsize, shuffle = true) @@ -411,9 +414,22 @@ function split_data( val_fold::Union{Nothing, Int} = nothing, shuffleobs::Bool = false, split_data_at::Real = 0.8, + sequence_kwargs::Union{Nothing, NamedTuple} = nothing, + array_type::Symbol = :KeyedArray, kwargs... ) - data_ = prepare_data(hybridModel, data) + data_ = prepare_data(hybridModel, data; array_type = array_type) + + if sequence_kwargs !== nothing + x_keyed, y_keyed = data_ + sis_default = (; input_window = 10, output_window = 1, shift = 1, lead_time = 1) + sis = merge(sis_default, sequence_kwargs) + @info "Using split_into_sequences: $sis" + x_all, y_all = split_into_sequences(x_keyed, y_keyed; sis.input_window, sis.output_window, sis.shift, sis.lead_time) + else + x_all, y_all = data_ + end + if split_by_id !== nothing && folds !== nothing @@ -431,9 +447,8 @@ function split_data( @info "Number of unique $(split_by_id): $(length(unique_ids))" @info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))" - x_all, y_all = data_ - x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx) - x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx) + x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx) + x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx) return (x_train, y_train), (x_val, y_val) elseif folds !== nothing || val_fold !== nothing @@ -441,7 +456,6 @@ function split_data( @assert val_fold !== nothing "Provide val_fold when using folds." @assert folds !== nothing "Provide folds when using val_fold." @warn "shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions" - x_all, y_all = data_ f = isa(folds, Symbol) ? getbyname(data, folds) : folds n = size(x_all, 2) @assert length(f) == n "length(folds) ($(length(f))) must equal number of samples/columns ($n)." 
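For reference, a usage sketch of this fold-based split (the `:fold` column name is hypothetical;
`folds` may also be a precomputed assignment vector of the same length as the data):

```julia
# train on observations whose fold id differs from val_fold, validate on the rest
(x_train, y_train), (x_val, y_val) = split_data(df, hybridModel; folds = :fold, val_fold = 2)
```
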
@@ -453,13 +467,13 @@ function split_data( @info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))" - x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx) - x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx) + x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx) + x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx) return (x_train, y_train), (x_val, y_val) else # --- Fallback: simple random/chronological split of prepared data --- - (x_train, y_train), (x_val, y_val) = splitobs(data_; at = split_data_at, shuffle = shuffleobs) + (x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs) return (x_train, y_train), (x_val, y_val) end end @@ -493,12 +507,19 @@ Split data into training and validation sets, either randomly, by grouping by ID """ function split_data end -function prepare_data(hm, data::KeyedArray) +function prepare_data(hm, data::KeyedArray; array_type = :KeyedArray) predictors_forcing, targets = get_prediction_target_names(hm) + # KeyedArray: use () syntax for views that are differentiable return (data(predictors_forcing), data(targets)) end -function prepare_data(hm, data::DataFrame) +function prepare_data(hm, data::AbstractDimArray; array_type = :DimArray) + predictors_forcing, targets = get_prediction_target_names(hm) + # DimArray: use [] syntax (copies, but differentiable) + return (data[inout = At(predictors_forcing)], data[inout = At(targets)]) +end + +function prepare_data(hm, data::DataFrame; array_type = :KeyedArray) predictors_forcing, targets = get_prediction_target_names(hm) all_predictor_cols = unique(vcat(values(predictors_forcing)...)) @@ -522,17 +543,16 @@ function prepare_data(hm, data::DataFrame) keep = .!mask_missing_predforce .& mask_at_least_one_target sdf = sdf[keep, col_to_select] - # Convert to Float32 and to your keyed array - ds_keyed = to_keyedArray(Float32.(sdf)) - return prepare_data(hm, ds_keyed) -end - -function prepare_data(hm, data::AbstractDimArray) - predictors_forcing, targets = get_prediction_target_names(hm) - return (data[col = At(predictors_forcing)], data[col = At(targets)]) # TODO check what this should be rows or cols, I would say rows, but maybe it does not matter + # Convert to Float32 and to the specified array type + if array_type == :KeyedArray + ds = to_keyedArray(Float32.(sdf)) + else + ds = to_dimArray(Float32.(sdf)) + end + return prepare_data(hm, ds; array_type = array_type) end -function prepare_data(hm, data::Tuple) +function prepare_data(hm, data::Tuple; array_type = :DimArray) return data end @@ -610,6 +630,70 @@ function getbyname(df::DataFrame, name::Symbol) return df[!, name] end -function getbyname(ka::AxisKeys.KeyedArray, name::Symbol) - return ka(name) +function getbyname(ka::Union{KeyedArray, AbstractDimArray}, name::Symbol) + return @view ka[inout = At(name)] +end + +function split_into_sequences(x, y; input_window = 5, output_window = 1, shift = 1, lead_time = 1) + ndims(x) == 2 || throw(ArgumentError("expected x to be (feature, time); got ndims(x) = $(ndims(x))")) + ndims(y) == 2 || throw(ArgumentError("expected y to be (target, time); got ndims(y) = $(ndims(y))")) + + Lx, Ly = size(x, 2), size(y, 2) + Lx == Ly || throw(ArgumentError("x and y must have same time length; got $Lx vs $Ly")) + lead_time ≥ 0 || throw(ArgumentError("lead_time must be ≥ 0 (0 = instantaneous end)")) + + nfeat, ntarget = size(x, 1), size(y, 1) + L = Lx + + 
featkeys = axiskeys(x, 1) + timekeys = axiskeys(x, 2) + targetkeys = axiskeys(y, 1) + + lead_start = lead_time - output_window + 1 + + lag_keys = Symbol.(["x$(lag)" for lag in (input_window + lead_time - 1):-1:lead_time]) + lead_keys = Symbol.(["_y$(lead)" for lead in ((output_window - 1):-1:0)]) + lead_keys = Symbol.(lag_keys[(end - length(lead_keys) + 1):end], lead_keys) + lag_keys[(end - length(lead_keys) + 1):end] .= lead_keys + + sx_min = max(1, 1 - (input_window + lead_time - output_window)) + sx_max = L - input_window - lead_time + 1 + sx_min <= sx_max || throw(ArgumentError("windows too long for series length")) + + sx_vals = collect(sx_min:shift:sx_max) + num_samples = length(sx_vals) + num_samples ≥ 1 || throw(ArgumentError("no samples with given shift/windows")) + + samplekeys = timekeys[sx_vals] + + Xd = zeros(Float32, nfeat, input_window, num_samples) + Yd = zeros(Float32, ntarget, output_window, num_samples) + + @inbounds @views for (ii, sx) in enumerate(sx_vals) + ex = sx + input_window - 1 + sy = ex + lead_start + ey = ex + lead_time + Xd[:, :, ii] .= x[:, sx:ex] + Yd[:, :, ii] .= y[:, sy:ey] + end + if x isa KeyedArray + Xk = KeyedArray(Xd; inout = featkeys, time = lag_keys, batch_size = samplekeys) + Yk = KeyedArray(Yd; inout = targetkeys, time = lead_keys, batch_size = samplekeys) + return Xk, Yk + elseif x isa AbstractDimArray + Xk = DimArray(Xd, (inout = featkeys, time = lag_keys, batch_size = samplekeys)) + Yk = DimArray(Yd, (inout = targetkeys, time = lead_keys, batch_size = samplekeys)) + return Xk, Yk + else + throw(ArgumentError("expected Xd to be KeyedArray or AbstractDimArray; got $(typeof(Xd))")) + end +end + + +function view_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx) + return view(x_all, :, idx) +end + +function view_end_dim(x_all::Union{KeyedArray{Float32, 3}, AbstractDimArray{Float32, 3}}, idx) + return view(x_all, :, :, idx) end diff --git a/src/utils/compute_loss.jl b/src/utils/compute_loss.jl index a1e63151..021897cc 100644 --- a/src/utils/compute_loss.jl +++ b/src/utils/compute_loss.jl @@ -53,6 +53,7 @@ function _compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function) end function _compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function) + out_loss_types = [ begin losses = assemble_loss(ŷ, y, y_nan, targets, loss_type) @@ -84,9 +85,31 @@ Returns a single loss value if `loss_spec` is provided, or a NamedTuple of losse """ function _compute_loss end +# Wrapper for time-based subsetting - dispatches on array type for differentiability +_select_time(ŷ_t::KeyedArray, time_keys) = ŷ_t(time = time_keys) # KeyedArray: () syntax - view & differentiable +_select_time(ŷ_t::AbstractDimArray, time_keys) = ŷ_t[time = At(time_keys)] # DimArray: [] syntax - copy & differentiable + + +# For 2D y_t (from 3D y): needs time subsetting +# y_t has dims (time, batch_size), ŷ[target] has (time=input_window, batch_size) +# We subset ŷ to match y_t's time dimension (output_window) +_get_target_ŷ(ŷ, y_t::Union{KeyedArray{T, 2}, AbstractDimArray{T, 2}}, target) where {T} = + _select_time(ŷ[target], axiskeys(y_t, :time)) + +# For 1D y_t (from 2D y): no time subsetting needed +_get_target_ŷ(ŷ, y_t::Union{KeyedArray{T, 1}, AbstractDimArray{T, 1}}, target) where {T} = + ŷ[target] + +_get_target_ŷ(ŷ, y_t, target) = + ŷ[target] + function assemble_loss(ŷ, y, y_nan, targets, loss_spec) return [ - _apply_loss(ŷ[target], _get_target_y(y, target), _get_target_nan(y_nan, target), loss_spec) + begin + y_t = _get_target_y(y, target) + ŷ_t 
= _get_target_ŷ(ŷ, y_t, target) + _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec) + end for target in targets ] end @@ -94,13 +117,18 @@ end function assemble_loss(ŷ, y, y_nan, targets, loss_spec::PerTarget) @assert length(targets) == length(loss_spec.losses) "Length of targets and PerTarget losses tuple must match" losses = [ - _apply_loss( - ŷ, - _get_target_y(y, target), - _get_target_nan(y_nan, target), - target, - loss_t - ) for (target, loss_t) in zip(targets, loss_spec.losses) + begin + y_t = _get_target_y(y, target) + ŷ_t = _get_target_ŷ(ŷ, y_t, target) + y_nan_t = _get_target_nan(y_nan, target) + _apply_loss( + ŷ_t, + y_t, + y_nan_t, + loss_t + ) + end + for (target, loss_t) in zip(targets, loss_spec.losses) ] return losses end @@ -117,7 +145,7 @@ function _apply_loss(ŷ, y, y_nan, loss_spec::Tuple) return loss_fn(ŷ, y, y_nan, loss_spec) end function _apply_loss(ŷ, y, y_nan, target, loss_spec) - return _apply_loss(ŷ[target], y, y_nan, loss_spec) + return _apply_loss(_get_target_ŷ(ŷ, y, target), y, y_nan, loss_spec) end """ @@ -137,9 +165,27 @@ Helper function to apply the appropriate loss function based on the specificatio function _apply_loss end _get_target_y(y, target) = y(target) -_get_target_y(y::AbstractDimArray, target) = y[col = At(target)] # assumes the DimArray uses :col indexing -_get_target_y(y::AbstractDimArray, targets::Vector) = y[col = At(targets)] # for multiple targets +_get_target_nan(y_nan, target) = y_nan(target) + +# For KeyedArray +function _get_target_y(y::KeyedArray, target) + return y(inout = target) +end + +function _get_target_y(y::KeyedArray, targets::Vector) + return y(inout = targets) +end +# For DimArray +function _get_target_y(y::AbstractDimArray, target) + return y[inout = At(target)] +end + +function _get_target_y(y::AbstractDimArray, targets::Vector) + return y[inout = At(targets)] +end + +# For Tuple (e.g. (y_obs, y_sigma)), supports KeyedArray or DimArray as y_obs function _get_target_y(y::Tuple, target) y_obs, y_sigma = y sigma = y_sigma isa Number ? y_sigma : y_sigma(target) @@ -147,16 +193,29 @@ function _get_target_y(y::Tuple, target) return (y_obs_val, sigma) end - """ _get_target_y(y, target) Helper function to extract target-specific values from `y`, handling cases where `y` can be a tuple of `(y_obs, y_sigma)`. 
""" function _get_target_y end -_get_target_nan(y_nan, target) = y_nan(target) -_get_target_nan(y_nan::AbstractDimArray, target) = y_nan[col = At(target)] # assumes the DimArray uses :col indexing -_get_target_nan(y_nan::AbstractDimArray, targets::Vector) = y_nan[col = At(targets)] # for multiple targets +# For KeyedArray +function _get_target_nan(y_nan::KeyedArray, target) + return y_nan(inout = target) +end + +function _get_target_nan(y_nan::KeyedArray, targets::Vector) + return y_nan(inout = targets) +end + +# For DimArray +function _get_target_nan(y_nan::AbstractDimArray, target) + return y_nan[inout = At(target)] +end + +function _get_target_nan(y_nan::AbstractDimArray, targets::Vector) + return y_nan[inout = At(targets)] +end """ _get_target_nan(y_nan, target) @@ -179,3 +238,9 @@ end function _loss_name(loss_spec::Tuple) return _loss_name(loss_spec[1]) end + +import ChainRulesCore +import AxisKeys: KeyedArray +import ChainRulesCore: ProjectTo, InplaceableThunk, unthunk + +(project::ProjectTo{KeyedArray})(dx::InplaceableThunk) = project(unthunk(dx)) diff --git a/src/utils/io.jl b/src/utils/io.jl index 78bf5b6e..9b4d2b25 100644 --- a/src/utils/io.jl +++ b/src/utils/io.jl @@ -53,12 +53,12 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name) end function to_named_tuple(ka, target_names) - arrays = [Array(ka(k)) for k in target_names] + arrays = [Array(ka(inout = k)) for k in target_names] return NamedTuple{Tuple(target_names)}(arrays) end function to_named_tuple(ka::AbstractDimArray, target_names) - arrays = [Array(ka[col = At(k)]) for k in target_names] + arrays = [Array(ka[inout = At(k)]) for k in target_names] return NamedTuple{Tuple(target_names)}(arrays) end diff --git a/src/utils/tools.jl b/src/utils/tools.jl index ea99f66d..35e35e3e 100644 --- a/src/utils/tools.jl +++ b/src/utils/tools.jl @@ -1,5 +1,5 @@ #### Data handling -export select_predictors, to_keyedArray, split_data +export select_predictors, to_keyedArray, to_dimArray, split_data export toDataFrame, toNamedTuple, toArray # Make vec each entry of NamedTuple (since broadcast ist reserved) @@ -52,7 +52,16 @@ tokeyedArray(df::DataFrame) """ function to_keyedArray(df::DataFrame) d = Matrix(df) |> transpose - return KeyedArray(d, row = Symbol.(names(df)), col = 1:size(d, 2)) + return KeyedArray(d, inout = Symbol.(names(df)), batch_size = 1:size(d, 2)) +end + +# Convert a DataFrame to a DimArray where variables are in 1st dim (rows) +""" +to_dimArray(df::DataFrame) +""" +function to_dimArray(df::DataFrame) + d = Matrix(df) |> transpose |> Array + return DimArray(d, (Dim{:inout}(Symbol.(names(df))), Dim{:batch_size}(1:size(d, 2)))) end # Cast a grouped dataframe into a KeyedArray, where the group is the third dimension @@ -102,38 +111,128 @@ split_data(df::DataFrame, target, xvars, seqID; f=0.8, batchsize=32, shuffle=tru function split_data(df::DataFrame, target, xvars, seqID; f = 0.8, batchsize = 32, shuffle = true, partial = true) dfg = groupby(df, seqID) dkg = to_keyedArray(dfg) - #@show axiskeys(dkg)[1] # Do the partitioning via indices of the 3rd dimension (e.g. 
seqID) because # partition does not allow partitioning along that dimension (or even not arrays at all) idx_tr, idx_vali = partition(axiskeys(dkg)[3], f; shuffle) # wrap training data into Flux.DataLoader - x = dkg(row = xvars, seqID = idx_tr) - y = dkg(row = target, seqID = idx_tr) |> Array + x = dkg(inout = xvars, seqID = idx_tr) + y = dkg(inout = target, seqID = idx_tr) |> Array data_t = (; x, y) trainloader = Flux.DataLoader(data_t; batchsize, shuffle, partial) trainall = Flux.DataLoader(data_t; batchsize = size(x, 3), shuffle = false, partial = false) # wrap validation data into Flux.DataLoader - x = dkg(row = xvars, seqID = idx_vali) - y = dkg(row = target, seqID = idx_vali) |> Array + x = dkg(inout = xvars, seqID = idx_vali) + y = dkg(inout = target, seqID = idx_vali) |> Array data_v = (; x, y) valloader = Flux.DataLoader(data_v; batchsize = size(x, 3), shuffle = false, partial = false) return trainloader, valloader, trainall end -function toDataFrame(ka) - data_array = Array(ka') - df = DataFrame(data_array, ka.row) - df.index = ka.col - return df +using AxisKeys +using NamedDims: NamedDims # Required for NamedDims.dim with KeyedArrays +using DataFrames +using DimensionalData: DimensionalData, AbstractDimArray, Dim, DimArray, dims, lookup, At + +_key_to_colname(k) = k isa Symbol ? k : Symbol(string(k)) + +# Helper to get dimension index from dimension name (works for both KeyedArray and DimArray) +_dim_index(ka::KeyedArray, name::Symbol) = NamedDims.dim(ka, name) +function _dim_index(da::AbstractDimArray, name::Symbol) + dim_names = DimensionalData.name.(dims(da)) + idx = findfirst(==(name), dim_names) + isnothing(idx) && throw(ArgumentError("Dimension :$name not found in array with dimensions $dim_names")) + return idx end -function toDataFrame(ka::AbstractDimArray) - data_array = Array(ka') - df = DataFrame(data_array, Array(dims(ka, :col))) - df.index = Array(dims(ka, :row)) +# Helper to extract raw array data (works for both KeyedArray and DimArray) +_raw_array(ka::KeyedArray) = Array(AxisKeys.keyless(ka)) +_raw_array(da::AbstractDimArray) = Array(parent(da)) + +# Helper to select a single value along a named dimension +_select_at(ka::KeyedArray, dim_name::Symbol, key) = ka(; NamedTuple{(dim_name,)}((key,))...) +_select_at(da::AbstractDimArray, dim_name::Symbol, key) = view(da, Dim{dim_name}(At(key))) + +# 2D Labeled Array -> DataFrame (works for both KeyedArray and DimArray) +""" + toDataFrame(arr::Union{KeyedArray{T, 2}, AbstractDimArray{T, 2}}, cols_dim=:inout, index_dim=:batch_size; index_col=:index) + +Convert a 2D labeled array (KeyedArray or DimArray) to a DataFrame. + +# Arguments +- `arr`: The 2D labeled array to convert +- `cols_dim`: Dimension name to use as DataFrame columns (default: `:inout`) +- `index_dim`: Dimension name to use as DataFrame row index (default: `:batch_size`) +- `index_col`: Name for the index column in the result (default: `:index`) + +# Returns +- `DataFrame` with columns from `cols_dim` keys and an index column from `index_dim` keys +""" +function toDataFrame( + arr::Union{KeyedArray{T, 2}, AbstractDimArray{T, 2}}, + cols_dim::Symbol = :inout, + index_dim::Symbol = :batch_size; + index_col::Symbol = :index, + ) where {T} + + dcols = _dim_index(arr, cols_dim) + didx = _dim_index(arr, index_dim) + + # Reorder so rows=index_dim, cols=cols_dim (i.e., didx=1, dcols=2) + arr2 = (didx == 1 && dcols == 2) ? 
arr : permutedims(arr, (didx, dcols)) + + data = _raw_array(arr2) + col_names = _key_to_colname.(collect(axiskeys(arr2, 2))) + + df = DataFrame(data, col_names; makeunique = true) + df[!, index_col] = collect(axiskeys(arr2, 1)) return df end +# 3D Labeled Array -> Dict(slice_key => DataFrame) +""" + toDataFrame(arr::AbstractLabeledArray{T, 3}, cols_dim=:inout, index_dim=:batch_size; slice_dim=:time, index_col=:index) + +Convert a 3D labeled array (KeyedArray or DimArray) to a Dict of DataFrames, one per slice. + +# Arguments +- `arr`: The 3D labeled array to convert +- `cols_dim`: Dimension name to use as DataFrame columns (default: `:inout`) +- `index_dim`: Dimension name to use as DataFrame row index (default: `:batch_size`) +- `slice_dim`: Dimension name to slice along (default: `:time`) +- `index_col`: Name for the index column in each result DataFrame (default: `:index`) + +# Returns +- `Dict{Any, DataFrame}` mapping slice keys to DataFrames +""" +function toDataFrame( + arr::Union{KeyedArray{T, 3}, AbstractDimArray{T, 3}}, + cols_dim::Symbol = :inout, + index_dim::Symbol = :batch_size; + slice_dim::Symbol = :time, + index_col::Symbol = :index, + ) where {T} + + out = Dict{Any, DataFrame}() + for k in axiskeys(arr, slice_dim) + slice = _select_at(arr, slice_dim, k) + out[k] = toDataFrame(slice, cols_dim, index_dim; index_col = index_col) + end + return out +end + +# Convenience: extract specific targets from a labeled array into a DataFrame +""" + toDataFrame(arr, target_names) + +Extract specific target variables from a labeled array into a DataFrame with `_pred` suffix. + +# Arguments +- `arr`: A labeled array or NamedTuple-like object with property access +- `target_names`: Vector of target variable names to extract + +# Returns +- `DataFrame` with columns named `_pred` for each target +""" function toDataFrame(ka, target_names) data = [getproperty(ka, t_name) for t_name in target_names] @@ -148,15 +247,16 @@ function toDataFrame(ka, target_names) end # ============================================================================= -# KeyedArray unpacking functions +# Array unpacking functions (works for both KeyedArray and DimArray) # ============================================================================= """ -toNamedTuple(ka::KeyedArray, variables::Vector{Symbol}) -Extract specified variables from a KeyedArray and return them as a NamedTuple of vectors. + toNamedTuple(ka::Union{KeyedArray, AbstractDimArray}, variables::Vector{Symbol}) + +Extract specified variables from a KeyedArray or DimArray and return them as a NamedTuple of vectors. # Arguments: -- `ka`: The KeyedArray to unpack +- `ka`: The KeyedArray or DimArray to unpack - `variables`: Vector of symbols representing the variables to extract # Returns: @@ -164,29 +264,29 @@ Extract specified variables from a KeyedArray and return them as a NamedTuple of # Example: ```julia -# Extract SW_IN and TA from a KeyedArray -data = toNamedTuple(ds_keyed, [:SW_IN, :TA]) +# Extract SW_IN and TA from an array +data = toNamedTuple(ds, [:SW_IN, :TA]) sw_in = data.SW_IN ta = data.TA ``` """ function toNamedTuple(ka::KeyedArray, variables::Vector{Symbol}) - vals = [vec(ka([var])) for var in variables] + vals = [dropdims(ka(inout = [var]), dims = 1) for var in variables] return (; zip(variables, vals)...) 
+
+# Convenience: extract specific targets from a labeled array into a DataFrame
+"""
+    toDataFrame(arr, target_names)
+
+Extract specific target variables from a labeled array into a DataFrame with `_pred` suffix.
+
+# Arguments
+- `arr`: A labeled array or NamedTuple-like object with property access
+- `target_names`: Vector of target variable names to extract
+
+# Returns
+- `DataFrame` with columns named `_pred` for each target
+"""
 function toDataFrame(ka, target_names)
     data = [getproperty(ka, t_name) for t_name in target_names]
@@ -148,15 +247,16 @@ function toDataFrame(ka, target_names)
 end
 
 # =============================================================================
-# KeyedArray unpacking functions
+# Array unpacking functions (works for both KeyedArray and DimArray)
 # =============================================================================
 
 """
-toNamedTuple(ka::KeyedArray, variables::Vector{Symbol})
-Extract specified variables from a KeyedArray and return them as a NamedTuple of vectors.
+    toNamedTuple(ka::Union{KeyedArray, AbstractDimArray}, variables::Vector{Symbol})
+
+Extract specified variables from a KeyedArray or DimArray and return them as a NamedTuple of vectors.
 
 # Arguments:
-- `ka`: The KeyedArray to unpack
+- `ka`: The KeyedArray or DimArray to unpack
 - `variables`: Vector of symbols representing the variables to extract
 
 # Returns:
@@ -164,29 +264,29 @@ Extract specified variables from a KeyedArray and return them as a NamedTuple of
 
 # Example:
 ```julia
-# Extract SW_IN and TA from a KeyedArray
-data = toNamedTuple(ds_keyed, [:SW_IN, :TA])
+# Extract SW_IN and TA from an array
+data = toNamedTuple(ds, [:SW_IN, :TA])
 sw_in = data.SW_IN
 ta = data.TA
 ```
"""
 function toNamedTuple(ka::KeyedArray, variables::Vector{Symbol})
-    vals = [vec(ka([var])) for var in variables]
+    vals = [dropdims(ka(inout = [var]), dims = 1) for var in variables]
     return (; zip(variables, vals)...)
 end
 
 function toNamedTuple(ka::AbstractDimArray, variables::Vector{Symbol})
-    vals = [vec(ka[col = At(var)]) for var in variables]
+    vals = [dropdims(ka[inout = At([var])], dims = 1) for var in variables]
     return (; zip(variables, vals)...)
 end
 
 function toNamedTuple(ka::KeyedArray, variables::NTuple{N, Symbol}) where {N}
-    vals = ntuple(i -> vec(ka([variables[i]])), N)
+    vals = ntuple(i -> ka(inout = [variables[i]]), N)
     return NamedTuple{variables}(vals)
 end
 
 function toNamedTuple(ka::AbstractDimArray, variables::NTuple{N, Symbol}) where {N}
-    ntuple(i -> vec(ka[col = At([variables[i]])]), N)
+    vals = ntuple(i -> ka[inout = At([variables[i]])], N)
     return NamedTuple{variables}(vals)
 end
 
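+# Note (illustrative sketch, assuming `ka` has an :inout dimension): the NTuple methods above
+# keep the labeled 1 × batch_size slice per variable rather than dropping the :inout dimension:
+#   toNamedTuple(ka, (:ta,)).ta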
@@ -202,8 +302,8 @@ Extract all variables from a KeyedArray and return them as a NamedTuple of vecto
 
 # Example:
 ```julia
-# Extract all variables from a KeyedArray
-data = toNamedTuple(ds_keyed)
+# Extract all variables from an array
+data = toNamedTuple(ds)
 # Access individual variables
 sw_in = data.SW_IN
 ta = data.TA
@@ -211,12 +311,12 @@ nee = data.NEE
 ```
 """
 function toNamedTuple(ka::KeyedArray)
-    variables = Symbol.(axiskeys(ka)[1]) # Get all variable names from first dimension
+    variables = Symbol.(axiskeys(ka, :inout)) # Get all variable names from :inout dimension
     return toNamedTuple(ka, variables)
 end
 
 function toNamedTuple(ka::AbstractDimArray)
-    variables = Symbol.(dims(A, :col)) # Get all variable names from first dimension
+    variables = Symbol.(lookup(ka, :inout)) # Get all variable names from :inout dimension
     return toNamedTuple(ka, variables)
 end
 
@@ -225,7 +325,7 @@ toNamedTuple(ka::KeyedArray, variable::Symbol)
 Extract a single variable from a KeyedArray and return it as a vector.
 
 # Arguments:
-- `ka`: The KeyedArray to unpack
+- `ka`: The KeyedArray or DimArray to unpack
 - `variable`: Symbol representing the variable to extract
 
 # Returns:
@@ -233,22 +333,22 @@ Extract a single variable from a KeyedArray and return it as a vector.
 
 # Example:
 ```julia
-# Extract just SW_IN from a KeyedArray
-sw_in = toNamedTuple(ds_keyed, :SW_IN)
+# Extract just SW_IN from an array
+sw_in = toNamedTuple(ds, :SW_IN)
 ```
 """
 function toNamedTuple(ka::KeyedArray, variable::Symbol)
-    return vec(ka[variable])
+    return ka(inout = variable)
 end
 
 function toNamedTuple(ka::AbstractDimArray, variable::Symbol)
-    return vec(ka[col = At(variable)])
+    return ka[inout = At(variable)]
 end
 
 function toArray(ka::KeyedArray, variable)
-    return ka(variable)
+    return ka(inout = variable)
 end
 
 function toArray(ka::AbstractDimArray, variable)
-    return ka[col = At(variable)]
+    return ka[inout = At(variable)]
 end
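+
+# Illustrative usage sketch (assumes `ka` has an :inout dimension containing :ta):
+#   toArray(ka, [:ta])      # labeled 1 × batch_size slice
+#   toNamedTuple(ka, :ta)   # same selection with the :inout dimension dropped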
diff --git a/test/runtests.jl b/test/runtests.jl
index a772ca70..9461451d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,3 +1,6 @@
+#using Pkg
+#Pkg.activate(@__DIR__)
+#using Revise
 using EasyHybrid
 using Test
 # import Flux
diff --git a/test/test_compute_loss.jl b/test/test_compute_loss.jl
index fb87339f..92cbac2e 100644
--- a/test/test_compute_loss.jl
+++ b/test/test_compute_loss.jl
@@ -72,8 +72,8 @@ using DataFrames
         :var1 => DimArray([1.0, 2.0, 3.0], (Ti(1:3),)),
         :var2 => DimArray([2.0, 3.0, 4.0], (Ti(1:3),))
     )
-    y_dim = DimArray([1.1 1.8; 1.9 3.1; 3.2 3.9], (Ti(1:3), Dim{:col}([:var1, :var2])))
-    y_nan_dim = DimArray(trues(3, 2), (Ti(1:3), Dim{:col}([:var1, :var2])))
+    y_dim = DimArray([1.1 1.8; 1.9 3.1; 3.2 3.9], (Ti(1:3), Dim{:inout}([:var1, :var2])))
+    y_nan_dim = DimArray(trues(3, 2), (Ti(1:3), Dim{:inout}([:var1, :var2])))
 
     # Test single predefined loss
     loss = _compute_loss(ŷ_dim, y_dim, y_nan_dim, targets, :mse, sum)
@@ -121,12 +121,12 @@ end
     @test _get_target_nan(y_nan_func, :var2) == [true, true, false]
 
     # Test with AbstractDimArray
-    y_nan_dim = DimArray([true false; true true; false true], (Ti(1:3), Dim{:col}([:var1, :var2])))
+    y_nan_dim = DimArray([true false; true true; false true], (Ti(1:3), Dim{:inout}([:var1, :var2])))
     @test _get_target_nan(y_nan_dim, :var1) == [true, true, false]
     @test _get_target_nan(y_nan_dim, :var2) == [false, true, true]
 
     # Test with Vector of targets
-    y_nan_dim_multi = DimArray([true false; true true; false true], (Ti(1:3), Dim{:col}([:var1, :var2])))
+    y_nan_dim_multi = DimArray([true false; true true; false true], (Ti(1:3), Dim{:inout}([:var1, :var2])))
     result = _get_target_nan(y_nan_dim_multi, [:var1, :var2])
     @test size(result) == (3, 2)
     @test result[:, 1] == [true, true, false]
@@ -140,12 +140,12 @@ end
     @test _get_target_y(y_func, :var2) == [2.0, 3.0, 4.0]
 
     # Test with AbstractDimArray
-    y_dim = DimArray([1.0 2.0; 2.0 3.0; 3.0 4.0], (Ti(1:3), Dim{:col}([:var1, :var2])))
+    y_dim = DimArray([1.0 2.0; 2.0 3.0; 3.0 4.0], (Ti(1:3), Dim{:inout}([:var1, :var2])))
     @test _get_target_y(y_dim, :var1) == [1.0, 2.0, 3.0]
     @test _get_target_y(y_dim, :var2) == [2.0, 3.0, 4.0]
 
     # Test with Vector of targets
-    y_dim_multi = DimArray([1.0 2.0; 2.0 3.0; 3.0 4.0], (Ti(1:3), Dim{:col}([:var1, :var2])))
+    y_dim_multi = DimArray([1.0 2.0; 2.0 3.0; 3.0 4.0], (Ti(1:3), Dim{:inout}([:var1, :var2])))
     result = _get_target_y(y_dim_multi, [:var1, :var2])
     @test size(result) == (3, 2)
     @test result[:, 1] == [1.0, 2.0, 3.0]
diff --git a/test/test_generic_hybrid_model.jl b/test/test_generic_hybrid_model.jl
index 55aac728..c0a4f502 100644
--- a/test/test_generic_hybrid_model.jl
+++ b/test/test_generic_hybrid_model.jl
@@ -176,7 +176,7 @@ end
     @test model isa SingleNNHybridModel
     @test model.predictors == predictors
     @test model.NN isa Chain
-    @test length(model.NN.layers) == 0 # Empty chain
+    @test typeof(model.NN.layers[1]) == Lux.NoOpLayer # Empty chain
 end
 
 @testset "SingleNNHybridModel initialparameters" begin
@@ -496,7 +496,7 @@ end
     st = LuxCore.initialstates(rng, model)
     @test haskey(ps, :ps)
     # Even with empty NN, ps key exists (may be empty)
-    @test isempty(ps.ps)
+    @test isempty(ps.ps[1])
 
     output, new_st = model(dk, ps, st)
     @test haskey(output, :y_pred)
diff --git a/test/test_split_data_train.jl b/test/test_split_data_train.jl
index da42715b..8ee45846 100644
--- a/test/test_split_data_train.jl
+++ b/test/test_split_data_train.jl
@@ -121,7 +121,7 @@ const RbQ10_PARAMS = (
     @test !isnothing(out)
 
     mat = vcat(ka[1], ka[2])
-    da = DimArray(mat, (Dim{:col}(mat.keys[1]), Dim{:row}(1:size(mat, 2))))'
+    da = DimArray(mat, (Dim{:inout}(mat.keys[1]), Dim{:batch_size}(1:size(mat, 2))))'
     ka = prepare_data(model, da)
     @test !isnothing(ka)