@@ -7,25 +7,50 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem,
77 AbstractIntegralProblem, AbstractSteadyStateProblem,
88 AbstractJumpProblem}
99
10- function init_call (_prob, args... ; merge_callbacks = true , kwargshandle = nothing ,
11- kwargs... )
12- kwargshandle = kwargshandle === nothing ? SciMLBase. KeywordArgError : kwargshandle
13- kwargshandle = has_kwargs (_prob) && haskey (_prob. kwargs, :kwargshandle ) ?
14- _prob. kwargs[:kwargshandle ] : kwargshandle
10+ """
11+ merge_problem_kwargs(prob; merge_callbacks=true, kwargs…)
1512
16- if has_kwargs (_prob)
17- if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
13+ Merges kwargs stored in `prob.kwargs` with the provided `kwargs`, following
14+ DiffEq's standard merging rules:
15+ - Problem kwargs are merged first
16+ - Passed kwargs take precedence (i.e., they override problem kwargs)
17+ - If `merge_callbacks=true` and both prob and kwargs have callbacks, they are
18+ merged into a `CallbackSet` rather than one overriding the other
19+
20+ Returns the merged kwargs as a Base.pairs.
21+
22+ This function is intended for use by problem types that override `__solve` or `__init`
23+ and need to manually handle kwargs merging that would normally be done by `solve_call`
24+ or `init_call`.
25+ """
26+ function merge_problem_kwargs (prob; merge_callbacks = true , kwargs... )
27+
28+ # Special handling for callback merging
29+ if has_kwargs (prob)
30+ if merge_callbacks && haskey (prob. kwargs, :callback ) && haskey (kwargs, :callback )
1831 kwargs_temp = NamedTuple{
1932 Base. diff_names (Base. _nt_names (values (kwargs)),
2033 (:callback ,))}(values (kwargs))
2134 callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
22- _prob . kwargs[:callback ],
35+ prob . kwargs[:callback ],
2336 values (kwargs). callback),))
2437 kwargs = merge (kwargs_temp, callbacks)
2538 end
26- kwargs = isempty (_prob . kwargs) ? kwargs : merge (values (_prob . kwargs), kwargs)
39+ kwargs = isempty (prob . kwargs) ? kwargs : merge (values (prob . kwargs), kwargs)
2740 end
2841
42+ return kwargs
43+ end
44+
45+ function init_call (_prob, args... ; merge_callbacks = true , kwargshandle = nothing ,
46+ kwargs... )
47+ kwargshandle = kwargshandle === nothing ? SciMLBase. KeywordArgError : kwargshandle
48+ kwargshandle = has_kwargs (_prob) && haskey (_prob. kwargs, :kwargshandle ) ?
49+ _prob. kwargs[:kwargshandle ] : kwargshandle
50+
51+ # Merge problem kwargs with passed kwargs
52+ kwargs = merge_problem_kwargs (_prob; merge_callbacks, kwargs... )
53+
2954 checkkwargs (kwargshandle; kwargs... )
3055
3156 if _prob isa Union{ODEProblem, DAEProblem} && isnothing (_prob. u0)
@@ -87,18 +112,8 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi
87112 kwargshandle = has_kwargs (_prob) && haskey (_prob. kwargs, :kwargshandle ) ?
88113 _prob. kwargs[:kwargshandle ] : kwargshandle
89114
90- if has_kwargs (_prob)
91- if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
92- kwargs_temp = NamedTuple{
93- Base. diff_names (Base. _nt_names (values (kwargs)),
94- (:callback ,))}(values (kwargs))
95- callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
96- _prob. kwargs[:callback ],
97- values (kwargs). callback),))
98- kwargs = merge (kwargs_temp, callbacks)
99- end
100- kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
101- end
115+ # Merge problem kwargs with passed kwargs
116+ kwargs = merge_problem_kwargs (_prob; merge_callbacks, kwargs... )
102117
103118 checkkwargs (kwargshandle; kwargs... )
104119 if isdefined (_prob, :u0 )
@@ -853,18 +868,8 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba
853868 _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
854869 end
855870
856- if has_kwargs (_prob)
857- if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
858- kwargs_temp = NamedTuple{
859- Base. diff_names (Base. _nt_names (values (kwargs)),
860- (:callback ,))}(values (kwargs))
861- callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
862- _prob. kwargs[:callback ],
863- values (kwargs). callback),))
864- kwargs = merge (kwargs_temp, callbacks)
865- end
866- kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
867- end
871+ # Merge problem kwargs with passed kwargs
872+ kwargs = merge_problem_kwargs (_prob; merge_callbacks, kwargs... )
868873
869874 if length (args) > 1
870875 _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator,
@@ -884,18 +889,8 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba
884889 _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
885890 end
886891
887- if has_kwargs (_prob)
888- if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
889- kwargs_temp = NamedTuple{
890- Base. diff_names (Base. _nt_names (values (kwargs)),
891- (:callback ,))}(values (kwargs))
892- callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
893- _prob. kwargs[:callback ],
894- values (kwargs). callback),))
895- kwargs = merge (kwargs_temp, callbacks)
896- end
897- kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
898- end
892+ # Merge problem kwargs with passed kwargs
893+ kwargs = merge_problem_kwargs (_prob; merge_callbacks, kwargs... )
899894
900895 if length (args) > 1
901896 _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator,
0 commit comments