-
-
Notifications
You must be signed in to change notification settings - Fork 38
SimpleExplicitTauLeaping solver #513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 35 commits
f5ad900
aae5d9f
230d508
25b8c16
07c429e
885ac59
0ec9d39
0e7ff26
89378c6
14e0be7
e7f975e
6e789cd
a8999f4
0125e21
1af190d
10f4ce3
6d3d900
2c03d67
3f90750
fb72149
7a7232a
fe7cec0
8e7ff16
bc770d1
b5f77f5
0b72d4c
5415947
b39390f
822562f
b572987
785266b
b47df7c
e02d432
092d361
7217cf0
cc3a78a
48fece2
12f84ba
98d64f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,11 @@ | ||||||
| struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end | ||||||
|
|
||||||
| struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm | ||||||
| epsilon::T # Error control parameter | ||||||
| end | ||||||
|
|
||||||
| SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) | ||||||
|
|
||||||
| function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) | ||||||
| if !(jump_prob.aggregator isa PureLeaping) | ||||||
| @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ | ||||||
|
|
@@ -14,6 +20,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) | |||||
| jump_prob.regular_jump !== nothing | ||||||
| end | ||||||
|
|
||||||
| function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping) | ||||||
| if !(jump_prob.aggregator isa PureLeaping) | ||||||
| @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ | ||||||
| JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ | ||||||
| Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." | ||||||
| end | ||||||
| isempty(jump_prob.jump_callback.continuous_callbacks) && | ||||||
| isempty(jump_prob.jump_callback.discrete_callbacks) && | ||||||
| isempty(jump_prob.constant_jumps) && | ||||||
| isempty(jump_prob.variable_jumps) && | ||||||
| jump_prob.massaction_jump !== nothing | ||||||
| end | ||||||
|
|
||||||
| function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; | ||||||
| seed = nothing, dt = error("dt is required for SimpleTauLeaping.")) | ||||||
| validate_pure_leaping_inputs(jump_prob, alg) || | ||||||
|
|
@@ -61,6 +80,215 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; | |||||
| interp = DiffEqBase.ConstantInterpolation(t, u)) | ||||||
| end | ||||||
|
|
||||||
| function compute_hor(reactant_stoch, numjumps) | ||||||
| # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. | ||||||
| # HOR is the sum of stoichiometric coefficients of reactants in reaction j. | ||||||
| hor = zeros(Int, numjumps) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Instead of |
||||||
| for j in 1:numjumps | ||||||
| order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if order > 3 | ||||||
| error("Reaction $j has order $order, which is not supported (maximum order is 3).") | ||||||
| end | ||||||
| hor[j] = order | ||||||
| end | ||||||
| return hor | ||||||
| end | ||||||
|
|
||||||
| function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) | ||||||
| # Precompute reaction conditions for each species i, storing reactions j where i is a reactant, | ||||||
| # along with stoichiometry (nu_ij) and HOR (hor_j), to optimize compute_gi. | ||||||
| # Reactant stoichiometry is used per Cao et al. (2006), Section IV, for g_i calculations. | ||||||
| reaction_conditions = [Vector() for _ in 1:numspecies] | ||||||
| for j in 1:numjumps | ||||||
| for (spec_idx, stoch) in reactant_stoch[j] | ||||||
| if stoch > 0 # Species is a reactant | ||||||
| push!(reaction_conditions[spec_idx], (j, stoch, hor[j])) | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| return reaction_conditions | ||||||
| end | ||||||
|
|
||||||
| function compute_gi(u, reaction_conditions, i) | ||||||
| # Compute g_i for species i to bound the relative change in propensity functions, | ||||||
| # as per Cao et al. (2006), Section IV, equation (27). | ||||||
| # g_i is determined by the highest order of reaction (HOR) where species i is a reactant: | ||||||
| # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 | ||||||
| # - HOR = 2 (second-order): | ||||||
| # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 | ||||||
| # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) | ||||||
| # - HOR = 3 (third-order): | ||||||
| # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 | ||||||
| # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) | ||||||
| # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) | ||||||
| # Uses precomputed reaction_conditions to optimize checks for HOR = 2 or 3 with nu_ij >= 2. | ||||||
| max_hor = maximum(isempty(reaction_conditions[i]) ? 0 : [hor_j for (j, nu_ij, hor_j) in reaction_conditions[i]]) | ||||||
| max_gi = 1 | ||||||
| for (j, nu_ij, hor_j) in reaction_conditions[i] | ||||||
| if hor_j == max_hor | ||||||
| if hor_j == 1 | ||||||
| max_gi = max(max_gi, 1) | ||||||
| elseif hor_j == 2 | ||||||
| if nu_ij == 1 | ||||||
| max_gi = max(max_gi, 2) | ||||||
| elseif nu_ij == 2 | ||||||
| if u[i] > 1 # Ensure x_i - 1 > 0 | ||||||
| gi = 2 + 1 / (u[i] - 1) | ||||||
| max_gi = max(max_gi, ceil(Int64, gi)) | ||||||
| end | ||||||
| end | ||||||
| elseif hor_j == 3 | ||||||
| if nu_ij == 1 | ||||||
| max_gi = max(max_gi, 3) | ||||||
| elseif nu_ij == 2 | ||||||
| if u[i] > 1 # Ensure x_i - 1 > 0 | ||||||
| gi = 1.5 * (2 + 1 / (u[i] - 1)) | ||||||
| max_gi = max(max_gi, ceil(Int64, gi)) | ||||||
| end | ||||||
| elseif nu_ij == 3 | ||||||
| if u[i] > 2 # Ensure x_i - 2 > 0 | ||||||
| gi = 3 + 1 / (u[i] - 1) + 2 / (u[i] - 2) | ||||||
| max_gi = max(max_gi, ceil(Int64, gi)) | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| return max_gi | ||||||
sivasathyaseeelan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| end | ||||||
|
|
||||||
| function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin, reaction_conditions, numjumps) | ||||||
| # Compute the tau-leaping step-size using equation (8) from Cao et al. (2006): | ||||||
| # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } | ||||||
| # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): | ||||||
| # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) | ||||||
| # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). | ||||||
| rate(rate_cache, u, p, t) | ||||||
| tau = Inf | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| for i in 1:length(u) | ||||||
| mu = zero(eltype(u)) | ||||||
| sigma2 = zero(eltype(u)) | ||||||
| for j in 1:size(nu, 2) | ||||||
| mu += nu[i, j] * rate_cache[j] # Equation (9a) | ||||||
| sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) | ||||||
| end | ||||||
| gi = compute_gi(u, reaction_conditions, i) | ||||||
| bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use explicitly typed |
||||||
| mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Similar for the |
||||||
| sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) | ||||||
| tau = min(tau, mu_term, sigma_term) # Equation (8) | ||||||
| end | ||||||
| return max(tau, dtmin) | ||||||
isaacsas marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| end | ||||||
|
|
||||||
| function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; | ||||||
| seed = nothing, | ||||||
| dtmin = 1e-10, | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Use the type of |
||||||
| saveat = nothing) | ||||||
| validate_pure_leaping_inputs(jump_prob, alg) || | ||||||
| error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") | ||||||
|
|
||||||
| @unpack prob, rng = jump_prob | ||||||
| (seed !== nothing) && seed!(rng, seed) | ||||||
|
|
||||||
| maj = jump_prob.massaction_jump | ||||||
| numjumps = get_num_majumps(maj) | ||||||
| rj = jump_prob.regular_jump | ||||||
| # Extract rates | ||||||
| rate = rj !== nothing ? rj.rate : | ||||||
| (out, u, p, t) -> begin | ||||||
| for j in 1:numjumps | ||||||
| out[j] = evalrxrate(u, j, maj) | ||||||
| end | ||||||
| end | ||||||
|
Comment on lines
+196
to
+200
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this a standalone function instead of inline (you can use a functor design if you need to save internal state). |
||||||
| c = rj !== nothing ? rj.c : nothing | ||||||
| u0 = copy(prob.u0) | ||||||
| tspan = prob.tspan | ||||||
| p = prob.p | ||||||
|
|
||||||
| # Initialize current state and saved history | ||||||
| u_current = copy(u0) | ||||||
| t_current = tspan[1] | ||||||
| usave = [copy(u0)] | ||||||
| tsave = [tspan[1]] | ||||||
| rate_cache = zeros(float(eltype(u0)), numjumps) | ||||||
| counts = zero(rate_cache) | ||||||
| du = similar(u0) | ||||||
| t_end = tspan[2] | ||||||
| epsilon = alg.epsilon | ||||||
|
|
||||||
| # Extract net stoichiometry for state updates | ||||||
| nu = zeros(float(eltype(u0)), length(u0), numjumps) | ||||||
| for j in 1:numjumps | ||||||
| for (spec_idx, stoch) in maj.net_stoch[j] | ||||||
| nu[spec_idx, j] = stoch | ||||||
| end | ||||||
| end | ||||||
| # Extract reactant stoichiometry for hor and gi | ||||||
| reactant_stoch = maj.reactant_stoch | ||||||
| hor = compute_hor(reactant_stoch, numjumps) | ||||||
| reaction_conditions = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) | ||||||
|
|
||||||
| # Set up saveat_times | ||||||
| saveat_times = nothing | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Not needed. |
||||||
| if isnothing(saveat) | ||||||
| saveat_times = Vector{typeof(tspan[1])}() | ||||||
| elseif saveat isa Number | ||||||
| saveat_times = collect(range(tspan[1], tspan[2], step=saveat)) | ||||||
| else | ||||||
| saveat_times = collect(saveat) | ||||||
| end | ||||||
sivasathyaseeelan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| save_idx = 1 | ||||||
|
|
||||||
| while t_current < t_end | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For type stability reasons, this while loop should probably be a separate function you call. |
||||||
| rate(rate_cache, u_current, p, t_current) | ||||||
| tau = compute_tau_explicit(u_current, rate_cache, nu, p, t_current, epsilon, rate, dtmin, reaction_conditions, numjumps) | ||||||
| tau = min(tau, t_end - t_current) | ||||||
| if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] | ||||||
| tau = saveat_times[save_idx] - t_current | ||||||
| end | ||||||
| counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, if a particular rate is <= 0 why not just set the count to zero directly and avoid the call to |
||||||
| du .= 0 | ||||||
| if c !== nothing | ||||||
| c(du, u_current, p, t_current, counts, nothing) | ||||||
| else | ||||||
| for j in 1:numjumps | ||||||
| for (spec_idx, stoch) in maj.net_stoch[j] | ||||||
| du[spec_idx] += stoch * counts[j] | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| u_new = u_current + du | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is allocating... Have separate pre-declatred vectors for |
||||||
| if any(<(0), u_new) | ||||||
| # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 | ||||||
| tau /= 2 | ||||||
sivasathyaseeelan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| continue | ||||||
| end | ||||||
| for i in eachindex(u_new) | ||||||
| u_new[i] = max(u_new[i], 0) | ||||||
| end | ||||||
| t_new = t_current + tau | ||||||
|
Comment on lines
+266
to
+269
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't the step rejection if any component of |
||||||
|
|
||||||
| # Save state if at a saveat time or if saveat is empty | ||||||
| if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) | ||||||
| push!(usave, copy(u_new)) | ||||||
| push!(tsave, t_new) | ||||||
| if !isempty(saveat_times) && t_new >= saveat_times[save_idx] | ||||||
| save_idx += 1 | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
| u_current = u_new | ||||||
| t_current = t_new | ||||||
| end | ||||||
|
|
||||||
| sol = DiffEqBase.build_solution(prob, alg, tsave, usave, | ||||||
| calculate_error=false, | ||||||
| interp=DiffEqBase.ConstantInterpolation(tsave, usave)) | ||||||
| return sol | ||||||
| end | ||||||
|
|
||||||
| struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm | ||||||
| backend::Backend | ||||||
| cpu_offload::Float64 | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And please update other functions accordingly. Place these kind of design comments before the function, not right inside it.