Skip to content

Commit c04f690

Browse files
Merge pull request #1224 from isaacsas/apifun_for_kwarg_merging
API function for kwarg merging
2 parents 54afe65 + a87c740 commit c04f690

File tree

4 files changed

+269
-48
lines changed

4 files changed

+269
-48
lines changed

src/solve.jl

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,50 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem,
77
AbstractIntegralProblem, AbstractSteadyStateProblem,
88
AbstractJumpProblem}
99

10-
function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
11-
kwargs...)
12-
kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle
13-
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
14-
_prob.kwargs[:kwargshandle] : kwargshandle
10+
"""
11+
merge_problem_kwargs(prob; merge_callbacks=true, kwargs…)
1512
16-
if has_kwargs(_prob)
17-
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
13+
Merges kwargs stored in `prob.kwargs` with the provided `kwargs`, following
14+
DiffEq's standard merging rules:
15+
- Problem kwargs are merged first
16+
- Passed kwargs take precedence (i.e., they override problem kwargs)
17+
- If `merge_callbacks=true` and both prob and kwargs have callbacks, they are
18+
merged into a `CallbackSet` rather than one overriding the other
19+
20+
Returns the merged kwargs as a Base.pairs.
21+
22+
This function is intended for use by problem types that override `__solve` or `__init`
23+
and need to manually handle kwargs merging that would normally be done by `solve_call`
24+
or `init_call`.
25+
"""
26+
function merge_problem_kwargs(prob; merge_callbacks = true, kwargs...)
27+
28+
# Special handling for callback merging
29+
if has_kwargs(prob)
30+
if merge_callbacks && haskey(prob.kwargs, :callback) && haskey(kwargs, :callback)
1831
kwargs_temp = NamedTuple{
1932
Base.diff_names(Base._nt_names(values(kwargs)),
2033
(:callback,))}(values(kwargs))
2134
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
22-
_prob.kwargs[:callback],
35+
prob.kwargs[:callback],
2336
values(kwargs).callback),))
2437
kwargs = merge(kwargs_temp, callbacks)
2538
end
26-
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
39+
kwargs = isempty(prob.kwargs) ? kwargs : merge(values(prob.kwargs), kwargs)
2740
end
2841

42+
return kwargs
43+
end
44+
45+
function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
46+
kwargs...)
47+
kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle
48+
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
49+
_prob.kwargs[:kwargshandle] : kwargshandle
50+
51+
# Merge problem kwargs with passed kwargs
52+
kwargs = merge_problem_kwargs(_prob; merge_callbacks, kwargs...)
53+
2954
checkkwargs(kwargshandle; kwargs...)
3055

3156
if _prob isa Union{ODEProblem, DAEProblem} && isnothing(_prob.u0)
@@ -87,18 +112,8 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi
87112
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
88113
_prob.kwargs[:kwargshandle] : kwargshandle
89114

90-
if has_kwargs(_prob)
91-
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
92-
kwargs_temp = NamedTuple{
93-
Base.diff_names(Base._nt_names(values(kwargs)),
94-
(:callback,))}(values(kwargs))
95-
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
96-
_prob.kwargs[:callback],
97-
values(kwargs).callback),))
98-
kwargs = merge(kwargs_temp, callbacks)
99-
end
100-
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
101-
end
115+
# Merge problem kwargs with passed kwargs
116+
kwargs = merge_problem_kwargs(_prob; merge_callbacks, kwargs...)
102117

103118
checkkwargs(kwargshandle; kwargs...)
104119
if isdefined(_prob, :u0)
@@ -853,18 +868,8 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba
853868
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
854869
end
855870

856-
if has_kwargs(_prob)
857-
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
858-
kwargs_temp = NamedTuple{
859-
Base.diff_names(Base._nt_names(values(kwargs)),
860-
(:callback,))}(values(kwargs))
861-
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
862-
_prob.kwargs[:callback],
863-
values(kwargs).callback),))
864-
kwargs = merge(kwargs_temp, callbacks)
865-
end
866-
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
867-
end
871+
# Merge problem kwargs with passed kwargs
872+
kwargs = merge_problem_kwargs(_prob; merge_callbacks, kwargs...)
868873

869874
if length(args) > 1
870875
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator,
@@ -884,18 +889,8 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba
884889
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
885890
end
886891

887-
if has_kwargs(_prob)
888-
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
889-
kwargs_temp = NamedTuple{
890-
Base.diff_names(Base._nt_names(values(kwargs)),
891-
(:callback,))}(values(kwargs))
892-
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
893-
_prob.kwargs[:callback],
894-
values(kwargs).callback),))
895-
kwargs = merge(kwargs_temp, callbacks)
896-
end
897-
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
898-
end
892+
# Merge problem kwargs with passed kwargs
893+
kwargs = merge_problem_kwargs(_prob; merge_callbacks, kwargs...)
899894

900895
if length(args) > 1
901896
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator,
Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
using OrdinaryDiffEq
2-
3-
# Auto callback merging
1+
using OrdinaryDiffEq, DiffEqBase, Test
42

3+
# Basic auto callback merging test
54
do_nothing = DiscreteCallback((u, t, integrator) -> true,
65
integrator -> nothing)
76
problem = ODEProblem((u, p, t) -> -u,
@@ -10,3 +9,144 @@ problem = ODEProblem((u, p, t) -> -u,
109
solve(problem, Euler(),
1110
dt = 0.1,
1211
callback = do_nothing)
12+
13+
@testset "Callback merging through solve" begin
14+
# Test that callbacks passed to problem constructor are properly merged
15+
# with callbacks passed to solve
16+
17+
function lorenz!(du, u, p, t)
18+
du[1] = 10.0(u[2] - u[1])
19+
du[2] = u[1] * (28.0 - u[3]) - u[2]
20+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
21+
end
22+
u0 = [1.0, 0.0, 0.0]
23+
tspan = (0.0, 1.0)
24+
25+
# Create two callbacks that track when they're called
26+
cb1_called = Ref(false)
27+
condition1(u, t, integrator) = t - 0.3
28+
affect1!(integrator) = (cb1_called[] = true)
29+
cb1 = ContinuousCallback(condition1, affect1!)
30+
31+
cb2_called = Ref(false)
32+
condition2(u, t, integrator) = t - 0.7
33+
affect2!(integrator) = (cb2_called[] = true)
34+
cb2 = ContinuousCallback(condition2, affect2!)
35+
36+
# Test 1: Callback in problem constructor only
37+
cb1_called[] = false
38+
prob = ODEProblem(lorenz!, u0, tspan; callback = cb1)
39+
sol = solve(prob, Tsit5())
40+
@test cb1_called[]
41+
@test sol.t[end] 1.0
42+
43+
# Test 2: Callback in solve only
44+
cb2_called[] = false
45+
prob = ODEProblem(lorenz!, u0, tspan)
46+
sol = solve(prob, Tsit5(); callback = cb2)
47+
@test cb2_called[]
48+
@test sol.t[end] 1.0
49+
50+
# Test 3: Callbacks in both (should merge by default)
51+
cb1_called[] = false
52+
cb2_called[] = false
53+
prob = ODEProblem(lorenz!, u0, tspan; callback = cb1)
54+
sol = solve(prob, Tsit5(); callback = cb2)
55+
@test cb1_called[]
56+
@test cb2_called[]
57+
@test sol.t[end] 1.0
58+
59+
# Test 4: merge_callbacks = false (solve callback should override)
60+
cb1_called[] = false
61+
cb2_called[] = false
62+
prob = ODEProblem(lorenz!, u0, tspan; callback = cb1)
63+
sol = solve(prob, Tsit5(); callback = cb2, merge_callbacks = false)
64+
@test !cb1_called[] # cb1 should not be called
65+
@test cb2_called[] # cb2 should be called
66+
@test sol.t[end] 1.0
67+
end
68+
69+
@testset "Callback merging through init" begin
70+
# Test that callbacks are properly merged when using init instead of solve
71+
72+
function simple!(du, u, p, t)
73+
du[1] = -u[1]
74+
end
75+
u0 = [1.0]
76+
tspan = (0.0, 1.0)
77+
78+
cb1_called = Ref(false)
79+
condition1(u, t, integrator) = t - 0.3
80+
affect1!(integrator) = (cb1_called[] = true)
81+
cb1 = ContinuousCallback(condition1, affect1!)
82+
83+
cb2_called = Ref(false)
84+
condition2(u, t, integrator) = t - 0.7
85+
affect2!(integrator) = (cb2_called[] = true)
86+
cb2 = ContinuousCallback(condition2, affect2!)
87+
88+
# Test: Callbacks in both problem and init (should merge)
89+
cb1_called[] = false
90+
cb2_called[] = false
91+
prob = ODEProblem(simple!, u0, tspan; callback = cb1)
92+
integrator = init(prob, Tsit5(); callback = cb2)
93+
solve!(integrator)
94+
@test cb1_called[]
95+
@test cb2_called[]
96+
end
97+
98+
@testset "Other kwargs merging" begin
99+
# Test that non-callback kwargs are properly merged by checking they're accessible
100+
101+
function simple!(du, u, p, t)
102+
du[1] = -u[1]
103+
end
104+
u0 = [1.0]
105+
tspan = (0.0, 1.0)
106+
107+
# Test that problem kwargs are preserved and solve kwargs override
108+
prob = ODEProblem(simple!, u0, tspan; abstol = 1e-10, saveat = 0.1)
109+
110+
# Both abstol and saveat should be used
111+
sol = solve(prob, Tsit5())
112+
@test sol.t [i * 0.1 for i in 0:10] # saveat from problem kwargs
113+
114+
# Override saveat, keep abstol
115+
sol = solve(prob, Tsit5(); saveat = 0.5)
116+
@test sol.t [0.0, 0.5, 1.0] # saveat from solve kwargs
117+
118+
# Test with save_everystep
119+
prob_everystep = ODEProblem(simple!, u0, tspan; save_everystep = false, saveat = 0.5)
120+
sol = solve(prob_everystep, Tsit5())
121+
@test sol.t [0.0, 0.5, 1.0] # Only saved at saveat times, not every step
122+
123+
# Override save_everystep
124+
sol = solve(prob_everystep, Tsit5(); save_everystep = true)
125+
@test length(sol.t) > 3 # Should save at more than just saveat times
126+
end
127+
128+
@testset "kwargshandle merging" begin
129+
# Test that kwargshandle is properly respected from problem kwargs
130+
131+
function simple!(du, u, p, t)
132+
du[1] = -u[1]
133+
end
134+
u0 = [1.0]
135+
tspan = (0.0, 1.0)
136+
137+
# Problem with KeywordArgSilent should not warn on invalid kwargs
138+
prob = ODEProblem(simple!, u0, tspan;
139+
kwargshandle = SciMLBase.KeywordArgSilent,
140+
invalid_kwarg = "should be ignored")
141+
@test_nowarn sol = solve(prob, Tsit5())
142+
143+
# Problem with KeywordArgWarn and invalid kwarg should warn
144+
prob = ODEProblem(simple!, u0, tspan;
145+
kwargshandle = SciMLBase.KeywordArgWarn,
146+
invalid_kwarg = "should warn")
147+
@test_logs (:warn, SciMLBase.KWARGWARN_MESSAGE) sol = solve(prob, Tsit5())
148+
149+
# Default should error on invalid kwargs
150+
prob = ODEProblem(simple!, u0, tspan; invalid_kwarg = "should error")
151+
@test_throws SciMLBase.CommonKwargError sol = solve(prob, Tsit5())
152+
end

test/problem_kwargs_merging.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
using DiffEqBase, Test
2+
3+
# Test merge_problem_kwargs function - only using DiffEqBase directly
4+
5+
@testset "merge_problem_kwargs API" begin
6+
# Create a simple problem for testing
7+
function f(du, u, p, t)
8+
du[1] = -u[1]
9+
end
10+
11+
# Test 1: Problem with no kwargs
12+
prob_no_kwargs = ODEProblem(f, [1.0], (0.0, 1.0))
13+
kwargs_in = (abstol = 1e-6, reltol = 1e-6)
14+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_no_kwargs; kwargs_in...)
15+
@test kwargs_out == Base.pairs(kwargs_in)
16+
17+
# Test 2: Problem with empty kwargs
18+
prob_empty_kwargs = ODEProblem(f, [1.0], (0.0, 1.0); Dict{Symbol,Any}()...)
19+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_empty_kwargs; kwargs_in...)
20+
@test kwargs_out == Base.pairs(kwargs_in)
21+
22+
# Test 3: Problem kwargs are preserved, passed kwargs take precedence
23+
prob_with_kwargs = ODEProblem(f, [1.0], (0.0, 1.0); abstol = 1e-8, reltol = 1e-8)
24+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_kwargs; (abstol = 1e-6,)...)
25+
@test kwargs_out.abstol == 1e-6 # Passed kwarg takes precedence
26+
@test kwargs_out.reltol == 1e-8 # Problem kwarg is preserved
27+
28+
# Test 4: Multiple kwargs from problem and passed
29+
prob_multi = ODEProblem(f, [1.0], (0.0, 1.0); abstol = 1e-8, saveat = 0.1)
30+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_multi; (reltol = 1e-6, maxiters = 1000)...)
31+
@test kwargs_out.abstol == 1e-8 # From problem
32+
@test kwargs_out.saveat == 0.1 # From problem
33+
@test kwargs_out.reltol == 1e-6 # From passed kwargs
34+
@test kwargs_out.maxiters == 1000 # From passed kwargs
35+
36+
# Test 5: Callback merging disabled
37+
condition(u, t, integrator) = t - 0.5
38+
affect!(integrator) = (integrator.u[1] += 1.0)
39+
cb1 = ContinuousCallback(condition, affect!)
40+
cb2 = ContinuousCallback(condition, affect!)
41+
42+
prob_with_cb = ODEProblem(f, [1.0], (0.0, 1.0); callback = cb1)
43+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_cb; merge_callbacks = false,
44+
(callback = cb2,)...);
45+
@test kwargs_out.callback === cb2 # cb2 should override cb1
46+
47+
# Test 6: Callback merging enabled (default)
48+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_cb; (callback = cb2,)...)
49+
@test kwargs_out.callback isa DiffEqBase.CallbackSet
50+
@test length(kwargs_out.callback.continuous_callbacks) == 2
51+
52+
# Test 7: Callback merging with CallbackSet
53+
cb3 = ContinuousCallback(condition, affect!)
54+
cbset = CallbackSet(cb3)
55+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_cb; (callback = cbset,)...)
56+
@test kwargs_out.callback isa DiffEqBase.CallbackSet
57+
@test length(kwargs_out.callback.continuous_callbacks) == 2
58+
59+
# Test 8: Only problem has callback
60+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_cb; (abstol = 1e-6,)...)
61+
@test kwargs_out.callback === cb1
62+
@test kwargs_out.abstol == 1e-6
63+
64+
# Test 9: Only passed kwargs have callback
65+
prob_no_cb = ODEProblem(f, [1.0], (0.0, 1.0); abstol = 1e-8)
66+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_no_cb; (callback = cb2, reltol = 1e-6)...)
67+
@test kwargs_out.callback === cb2
68+
@test kwargs_out.abstol == 1e-8
69+
@test kwargs_out.reltol == 1e-6
70+
71+
# Test 10: DiscreteCallback merging
72+
dcb1 = DiscreteCallback((u, t, integrator) -> true, integrator -> nothing)
73+
dcb2 = DiscreteCallback((u, t, integrator) -> true, integrator -> nothing)
74+
prob_with_dcb = ODEProblem(f, [1.0], (0.0, 1.0); callback = dcb1)
75+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_dcb; (callback = dcb2,)...)
76+
@test kwargs_out.callback isa DiffEqBase.CallbackSet
77+
@test length(kwargs_out.callback.discrete_callbacks) == 2
78+
79+
# Test 11: Mixed Continuous and Discrete callback merging
80+
prob_with_ccb = ODEProblem(f, [1.0], (0.0, 1.0); callback = cb1)
81+
kwargs_out = DiffEqBase.merge_problem_kwargs(prob_with_ccb; (callback = dcb1,)...)
82+
@test kwargs_out.callback isa DiffEqBase.CallbackSet
83+
@test length(kwargs_out.callback.continuous_callbacks) == 1
84+
@test length(kwargs_out.callback.discrete_callbacks) == 1
85+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ end
3939
@time @safetestset "ForwardDiff Dual Detection" include("forwarddiff_dual_detection.jl")
4040
@time @safetestset "ODE default norm" include("ode_default_norm.jl")
4141
@time @safetestset "ODE default unstable check" include("ode_default_unstable_check.jl")
42+
@time @safetestset "Problem Kwargs Merging" include("problem_kwargs_merging.jl")
4243
end
4344

4445
if !is_APPVEYOR && GROUP == "Downstream"

0 commit comments

Comments
 (0)