Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6f83b93
kernels setup done
sivasathyaseeelan Jul 23, 2025
2320b9a
ImplicitTauLeaping setup done for jump problem solver
sivasathyaseeelan Jul 24, 2025
4e964ec
jump_problem solver fixed
sivasathyaseeelan Jul 28, 2025
f978ea9
removed equilibrium_pair logic
sivasathyaseeelan Jul 31, 2025
87222ff
tranculation error fixed
sivasathyaseeelan Jul 31, 2025
7a98d71
nonlinearsolver is implemented
sivasathyaseeelan Aug 1, 2025
9994a08
changed to SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
0eb637a
refactor
sivasathyaseeelan Aug 9, 2025
8973805
SimpleAdaptiveTauLeaping is done
sivasathyaseeelan Aug 9, 2025
741431f
simple version of SimpleImplicitTauLeaping
sivasathyaseeelan Aug 19, 2025
b5226f7
removed adaptive tau leap
sivasathyaseeelan Aug 19, 2025
6f868df
poiss change
sivasathyaseeelan Aug 19, 2025
2e5d82c
changed to inline non linear solver
sivasathyaseeelan Aug 19, 2025
f7ffa4d
refactor
sivasathyaseeelan Aug 19, 2025
439bb7d
typo
sivasathyaseeelan Aug 19, 2025
bbe9dc5
basic version of inplicit tau leap is done
sivasathyaseeelan Aug 19, 2025
5aa08e4
added critical_threshold
sivasathyaseeelan Aug 20, 2025
7f0c960
residual update
sivasathyaseeelan Aug 20, 2025
13711d2
added comment line
sivasathyaseeelan Aug 20, 2025
afbe6dd
SimpleImplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
0f7a40c
project.toml
sivasathyaseeelan Sep 5, 2025
f892dc8
project.toml
sivasathyaseeelan Sep 5, 2025
90f66af
some
sivasathyaseeelan Sep 5, 2025
4cce2a9
some
sivasathyaseeelan Sep 5, 2025
0b364f0
test update
sivasathyaseeelan Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand Down
4 changes: 3 additions & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using Base.Threads: Threads, @threads
using Base.FastMath: add_fast
using Setfield: @set, @set!

using SimpleNonlinearSolve

# Import functions we extend from Base
import Base: size, getindex, setindex!, length, similar, show, merge!, merge

Expand Down Expand Up @@ -129,7 +131,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleImplicitTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
255 changes: 255 additions & 0 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end

# Define solver type hierarchy
abstract type AbstractImplicitSolver end
struct NewtonImplicitSolver <: AbstractImplicitSolver end
struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end

struct SimpleImplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter for tau selection
solver::AbstractImplicitSolver # Solver type: Newton or Trapezoidal
end

SimpleImplicitTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = SimpleImplicitTauLeaping(epsilon, solver)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
Expand All @@ -14,6 +26,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
jump_prob.regular_jump !== nothing
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing, dt = error("dt is required for SimpleTauLeaping."))
validate_pure_leaping_inputs(jump_prob, alg) ||
Expand Down Expand Up @@ -61,6 +86,236 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
interp = DiffEqBase.ConstantInterpolation(t, u))
end

function compute_hor(reactant_stoch, numjumps)
# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
hor = zeros(Int, numjumps)
for j in 1:numjumps
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0)
if order > 3
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
end
return hor
end

function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps)
# Precompute reaction conditions for each species i, including:
# - max_hor: the highest order of reaction (HOR) where species i is a reactant.
# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor.
# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27).
max_hor = zeros(Int, numspecies)
max_stoich = zeros(Int, numspecies)
for j in 1:numjumps
for (spec_idx, stoch) in reactant_stoch[j]
if stoch > 0 # Species is a reactant
if hor[j] > max_hor[spec_idx]
max_hor[spec_idx] = hor[j]
max_stoich[spec_idx] = stoch
elseif hor[j] == max_hor[spec_idx]
max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch)
end
end
end
end
return max_hor, max_stoich
end

function compute_gi(u, max_hor, max_stoich, i, t)
# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
if max_hor[i] == 0 # No reactions involve species i as a reactant
return 1.0
elseif max_hor[i] == 1
return 1.0
elseif max_hor[i] == 2
if max_stoich[i] == 1
return 2.0
elseif max_stoich[i] == 2
return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1
end
elseif max_hor[i] == 3
if max_stoich[i] == 1
return 3.0
elseif max_stoich[i] == 2
return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1
elseif max_stoich[i] == 3
return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2
end
end
return 1.0 # Default case
end

function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
# Compute the tau-leaping step-size using equation (20) from Cao et al. (2006):
# tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) }
# where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b):
# mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x)
# I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified).
rate(rate_cache, u, p, t)
if all(==(0.0), rate_cache) # Handle case where all rates are zero
return dtmin
end
tau = Inf
for i in 1:length(u)
mu = zero(eltype(u))
sigma2 = zero(eltype(u))
for j in 1:size(nu, 2)
mu += nu[i, j] * rate_cache[j] # Equation (9a)
sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b)
end
gi = compute_gi(u, max_hor, max_stoich, i, t)
bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1)
mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8)
sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8)
tau = min(tau, mu_term, sigma_term) # Equation (8)
end
return max(tau, dtmin)
end

# Define residual for implicit equation
# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004)
# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau
function implicit_equation!(resid, u_new, params)
u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params
rate(rate_cache, u_new, p, t + tau)
resid .= u_new .- u_current
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
if isa(solver, NewtonImplicitSolver)
resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004)
else # TrapezoidalImplicitSolver
rate_current = similar(rate_cache)
rate(rate_current, u_current, p, t)
resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau
end
end
end
resid .= max.(resid, -u_new) # Ensure non-negative solution
end

# Solve implicit equation using SimpleNonlinearSolve
function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)
u_new = convert(Vector{Float64}, u_current)
prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver))
sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6)
return sol.u, sol.retcode == ReturnCode.Success
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
seed = nothing,
dtmin = 1e-10,
saveat = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

@unpack prob, rng = jump_prob
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate :
(out, u, p, t) -> begin
for j in 1:numjumps
out[j] = evalrxrate(u, j, maj)
end
end
c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

u_current = copy(u0)
t_current = tspan[1]
usave = [copy(u0)]
tsave = [tspan[1]]
rate_cache = zeros(Float64, numjumps)
counts = zeros(Int64, numjumps)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon
solver = alg.solver

nu = zeros(Int64, length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)

saveat_times = isnothing(saveat) ? Vector{Float64}() :
(saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat))
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end

u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver)
if !converged
tau /= 2
continue
end

rate(rate_cache, u_new_float, p, t_current + tau)
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
du .= zero(eltype(u_current))
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
u_new = u_current + du

if any(<(0), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3
tau /= 2
continue
end
# Ensure non-negativity, as per Cao et al. (2006), Section 3.3
for i in eachindex(u_new)
u_new[i] = max(u_new[i], 0)
end
t_new = t_current + tau

if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current = u_new
t_current = t_new
end

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error=false,
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading