From ec5851123ecb8c56afce3f2de5721118879f6f9d Mon Sep 17 00:00:00 2001
From: Lazaro Alonso
Date: Fri, 26 Sep 2025 17:51:39 +0200
Subject: [PATCH 1/4] yax

---
 projects/RbQ10/Project.toml | 1 +
 projects/RbQ10/Q10_dd.jl    | 9 ++++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/projects/RbQ10/Project.toml b/projects/RbQ10/Project.toml
index d1f9c50b..7d90ab46 100644
--- a/projects/RbQ10/Project.toml
+++ b/projects/RbQ10/Project.toml
@@ -3,4 +3,5 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
 GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
+YAXArrays = "c21b50f5-aa40-41ea-b809-c0f5e47bfa5c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/projects/RbQ10/Q10_dd.jl b/projects/RbQ10/Q10_dd.jl
index 23bcaf36..94dc06fd 100644
--- a/projects/RbQ10/Q10_dd.jl
+++ b/projects/RbQ10/Q10_dd.jl
@@ -91,7 +91,7 @@ ar = rand(3,3)
 A = DimArray(ar, (Y([:a,:b,:c]), X(1:3)));
 grad = Zygote.gradient(x -> sum(x[Y=At(:b)]), A)
 
-xy = EasyHybrid.split_data((ds_p_f, ds_t), 0.8, shuffle=true, rng=Random.default_rng())
+# xy = EasyHybrid.split_data((ds_p_f, ds_t), 0.8, shuffle=true, rng=Random.default_rng())
 
 EasyHybrid.get_prediction_target_names(RbQ10)
 
@@ -100,3 +100,10 @@
 xy1 = EasyHybrid.prepare_data(RbQ10, da)
 (x_train, y_train), (x_val, y_val) = EasyHybrid.split_data(da, RbQ10) # ; shuffleobs=false, split_data_at=0.8
 out = train(RbQ10, da, (:Q10, ); nepochs=200, batchsize=512, opt=Adam(0.01));
+
+using YAXArrays
+axDims = dims(da)
+
+ds_yax = YAXArray(axDims, da.data)
+
+out_yax = train(RbQ10, ds_yax, (:Q10, ); nepochs=200, batchsize=512, opt=Adam(0.01));

From adab80a9afc1cee4664974859e5b12c9de8d2299 Mon Sep 17 00:00:00 2001
From: Bernhard Ahrens
Date: Sat, 27 Sep 2025 22:54:19 +0200
Subject: [PATCH 2/4] YAXArray isnan handling and broadcast, map everything?

---
 projects/RbQ10/Q10_dd.jl | 12 +++++++++++-
 src/utils/loss_fn.jl     | 12 +++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/projects/RbQ10/Q10_dd.jl b/projects/RbQ10/Q10_dd.jl
index 94dc06fd..61e2e887 100644
--- a/projects/RbQ10/Q10_dd.jl
+++ b/projects/RbQ10/Q10_dd.jl
@@ -37,7 +37,7 @@ grads = backtrace(l)[1]
 # TODO: test DimArray inputs
 using DimensionalData, ChainRulesCore
 # mat = Matrix(df)'
-mat = Array(Matrix(df)')
+mat = Float32.(Array(Matrix(df)'))
 da = DimArray(mat, (Dim{:col}(Symbol.(names(df))), Dim{:row}(1:size(df,1))))
 
 ##! new dispatch
@@ -106,4 +106,14 @@ axDims = dims(da)
 
 ds_yax = YAXArray(axDims, da.data)
 
+ds_p_f = ds_yax[col=At(forcing_names ∪ predictor_names)]
+ds_t = ds_yax[col=At(target_names)]
+ds_t_nan = .!isnan.(ds_t) # produces a 1×35064 YAXArray{Float32, 2}, not a Bool mask
+ds_t_nan = map(x -> !isnan(x), ds_t) # 1×35064 YAXArray{Bool, 2}
+
+ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
+ls_logs = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode=false))
+acc_ = EasyHybrid.evaluate_acc(RbQ10, ds_p_f, ds_t, ds_t_nan, ps, st, [:mse, :r2], :mse, sum)
+
+# TODO how to proceed - would it work already for multiple targets?
 out_yax = train(RbQ10, ds_yax, (:Q10, ); nepochs=200, batchsize=512, opt=Adam(0.01));
diff --git a/src/utils/loss_fn.jl b/src/utils/loss_fn.jl
index 994e2980..061cc336 100644
--- a/src/utils/loss_fn.jl
+++ b/src/utils/loss_fn.jl
@@ -32,7 +32,17 @@ function loss_fn(ŷ, y, y_nan, ::Val{:rmse})
     return sqrt(mean(abs2, (ŷ[y_nan] .- y[y_nan])))
 end
 function loss_fn(ŷ, y, y_nan, ::Val{:mse})
-    return mean(abs2, (ŷ[y_nan] .- y[y_nan]))
+    # Option 1: convert to Array and compute the MSE
+    # yh = Array(ŷ[y_nan])
+    # yt = Array(y[y_nan])
+    # return mean(abs2, yh .- yt)
+
+    # Option 2: use the YAXArray directly, but map has to be used instead of broadcasting
+    return mean(x -> x, map((a,b)->(a-b)^2, ŷ[y_nan], y[y_nan]))
+
+    # Option 3 gives an error:
+    # return mean(abs2, (ŷ[y_nan] .- y[y_nan])) # ERROR: MethodError: no method matching to_yax(::Vector{Float32}); the function `to_yax` exists, but no method is defined for this combination of argument types
+    # I guess our model output would need to be a YAXArray and not a Vector{Float32}
 end
 function loss_fn(ŷ, y, y_nan, ::Val{:mae})
     return mean(abs, (ŷ[y_nan] .- y[y_nan]))
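Note on the mask workaround in PATCH 2: broadcasting `.!isnan.` over a YAXArray is reported above to come back with Float32 data, while `map` preserves the Bool eltype needed for logical indexing. A minimal standalone sketch of the pattern, assuming only YAXArrays, DimensionalData and Statistics; the dimension names and dummy values are illustrative, not taken from the project:

using YAXArrays, DimensionalData, Statistics

y = YAXArray((Dim{:col}([:R_soil]), Dim{:row}(1:4)), Float32[1.0 NaN 3.0 4.0])

m_bc  = .!isnan.(y)             # broadcast: comes back with Float32 data, per the patch comment
m_map = map(x -> !isnan(x), y)  # map: keeps eltype Bool, usable as a logical mask

# masked MSE via map, mirroring Option 2 in loss_fn.jl
ŷ = Float32[0.9 0.0 3.2 4.1]
mask = collect(m_map)           # plain Bool array for logical indexing
mse = mean(map((a, b) -> (a - b)^2, ŷ[mask], y.data[mask]))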
From 1d5b882353679b2e6698b545d5d667385676566f Mon Sep 17 00:00:00 2001
From: Lazaro Alonso
Date: Mon, 13 Oct 2025 16:16:38 +0200
Subject: [PATCH 3/4] run yax

---
 projects/RbQ10/Q10_dd.jl |  3 +++
 src/train.jl             | 10 +++++++---
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/projects/RbQ10/Q10_dd.jl b/projects/RbQ10/Q10_dd.jl
index 61e2e887..7f7954c9 100644
--- a/projects/RbQ10/Q10_dd.jl
+++ b/projects/RbQ10/Q10_dd.jl
@@ -110,6 +110,9 @@ ds_p_f = ds_yax[col=At(forcing_names ∪ predictor_names)]
 ds_t = ds_yax[col=At(target_names)]
 ds_t_nan = .!isnan.(ds_t) # produces a 1×35064 YAXArray{Float32, 2}, not a Bool mask
 ds_t_nan = map(x -> !isnan(x), ds_t) # 1×35064 YAXArray{Bool, 2}
+length(ds_t_nan)
+# is_no_nan = .!isnan.(y)
+
 
 ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
 ls_logs = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode=false))
diff --git a/src/train.jl b/src/train.jl
index 6e000c57..22f55684 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -122,8 +122,10 @@ function train(hybridModel, data, save_ps;
     opt_state = Optimisers.setup(opt, ps)
 
     # ? initial losses
-    is_no_nan_t = .!isnan.(y_train)
-    is_no_nan_v = .!isnan.(y_val)
+    # is_no_nan_t = .!isnan.(y_train)
+    is_no_nan_t = map(x -> !isnan(x), y_train)
+    # is_no_nan_v = .!isnan.(y_val)
+    is_no_nan_v = map(x -> !isnan(x), y_val)
 
     l_init_train, _, init_ŷ_train = evaluate_acc(hybridModel, x_train, y_train, is_no_nan_t, ps, st, loss_types, training_loss, agg)
     l_init_val, _, init_ŷ_val = evaluate_acc(hybridModel, x_val, y_val, is_no_nan_v, ps, st, loss_types, training_loss, agg)
@@ -194,7 +196,9 @@ for epoch in 1:nepochs
 
         for (x, y) in train_loader
             # ? check NaN indices before going forward, and pass filtered `x, y`.
-            is_no_nan = .!isnan.(y)
+            # is_no_nan = .!isnan.(y)
+            is_no_nan = map(x -> !isnan(x), y) # doing this due to the YAXArray Bool eltype issue
+
             if length(is_no_nan)>0 # ! be careful here, multivariate needs fine tuning
                 l, backtrace = Zygote.pullback((ps) -> lossfn(hybridModel, x, (y, is_no_nan), ps, st, LoggingLoss(training_loss=training_loss, agg=agg)), ps)
 
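On the multivariate TODO from PATCH 2 and the batch masking above: once `y` carries several target rows, a single flat mask mixes targets. One possible shape for per-target masking, sketched under the assumption of a targets × observations matrix layout; `masked_mse_per_target` is a hypothetical helper, not part of EasyHybrid:

using Statistics

# ŷ, y: n_targets × n_obs matrices; build one Bool mask per target row
function masked_mse_per_target(ŷ::AbstractMatrix, y::AbstractMatrix)
    map(1:size(y, 1)) do t
        mask = map(!isnan, view(y, t, :))   # map keeps eltype Bool, as in the patch
        any(mask) || return zero(eltype(ŷ)) # skip targets with no valid observations
        mean(abs2, view(ŷ, t, :)[mask] .- view(y, t, :)[mask])
    end
end

The resulting per-target vector could then be reduced with the existing `agg` argument (e.g. `sum`).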
From cb293ad07dd3532aff0bcbd51c0d59ab9135aee8 Mon Sep 17 00:00:00 2001
From: Lazaro Alonso
Date: Mon, 27 Oct 2025 09:37:47 +0100
Subject: [PATCH 4/4] test main

---
 projects/RbQ10/Project.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/projects/RbQ10/Project.toml b/projects/RbQ10/Project.toml
index 7d90ab46..fab7f5e8 100644
--- a/projects/RbQ10/Project.toml
+++ b/projects/RbQ10/Project.toml
@@ -5,3 +5,6 @@ EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
 GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
 YAXArrays = "c21b50f5-aa40-41ea-b809-c0f5e47bfa5c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[sources]
+YAXArrays = {rev = "main", url = "https://github.com/JuliaDataCubes/YAXArrays.jl"}
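PATCH 4 pins YAXArrays to its development branch via the `[sources]` section, which requires Julia 1.11 or newer. The equivalent one-off command in a REPL session, using the standard Pkg API, would be:

using Pkg
Pkg.add(url = "https://github.com/JuliaDataCubes/YAXArrays.jl", rev = "main")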