Skip to content

Commit a605ec1

Browse files
allow for tuples
1 parent 3083d48 commit a605ec1

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/solutions/ode_solutions.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@ function build_solution(
2727
N = length((size(prob.u0)..., length(u)))
2828
end
2929

30-
if has_analytic(prob.f)
30+
if typeof(prob.f) <: Tuple
31+
f = prob.f[1]
32+
else
33+
f = prob.f
34+
end
35+
36+
if has_analytic(f)
3137
u_analytic = Vector{typeof(prob.u0)}(0)
3238
errors = Dict{Symbol,eltype(prob.u0)}()
3339
sol = ODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(k),
@@ -45,9 +51,16 @@ function build_solution(
4551
end
4652

4753
function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true)
54+
55+
if typeof(sol.prob.f) <: Tuple
56+
f = sol.prob.f[1]
57+
else
58+
f = sol.prob.f
59+
end
60+
4861
if fill_uanalytic
4962
for i in 1:size(sol.u,1)
50-
push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0))
63+
push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0))
5164
end
5265
end
5366

@@ -61,7 +74,7 @@ function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true
6174
if sol.dense && dense_errors
6275
densetimes = collect(linspace(sol.t[1],sol.t[end],100))
6376
interp_u = sol(densetimes)
64-
interp_analytic = [sol.prob.f(Val{:analytic},t,sol.u[1]) for t in densetimes]
77+
interp_analytic = [f(Val{:analytic},t,sol.u[1]) for t in densetimes]
6578
sol.errors[:L∞] = maximum(vecvecapply((x)->abs.(x),interp_u-interp_analytic))
6679
sol.errors[:L2] = sqrt(recursive_mean(vecvecapply((x)->float.(x).^2,interp_u-interp_analytic)))
6780
end

src/solutions/rode_solutions.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ function build_solution(
2929
N = length((size(prob.u0)..., length(u)))
3030
end
3131

32-
if has_analytic(prob.f)
32+
if typeof(prob.f) <: Tuple
33+
f = prob.f[1]
34+
else
35+
f = prob.f
36+
end
37+
38+
if has_analytic(f)
3339
u_analytic = Vector{typeof(prob.u0)}(0)
3440
errors = Dict{Symbol,eltype(prob.u0)}()
3541
sol = RODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(W),
@@ -49,9 +55,16 @@ function build_solution(
4955
end
5056

5157
function calculate_solution_errors!(sol::AbstractRODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true)
58+
59+
if typeof(sol.prob.f) <: Tuple
60+
f = sol.prob.f[1]
61+
else
62+
f = sol.prob.f
63+
end
64+
5265
if fill_uanalytic
5366
for i in 1:size(sol.u,1)
54-
push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i]))
67+
push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i]))
5568
end
5669
end
5770

0 commit comments

Comments
 (0)