Skip to content

Commit 2ac8579

Browse files
committed
solver interface run_ci
1 parent e2abc77 commit 2ac8579

File tree

13 files changed

+64
-111
lines changed

13 files changed

+64
-111
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DirectTrajectoryOptimization"
22
uuid = "a9f42406-efe7-414c-8b71-df971cc98041"
33
authors = ["thowell <[email protected]>"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

examples/acrobot/acrobot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ bnds = [bnd1, [bndt for t = 2:T-1]..., bndT]
103103
cons = [Constraint() for t = 1:T]
104104

105105
# ## problem
106-
p = ProblemData(obj, dyn, cons, bnds, options=Options())
106+
p = solver(dyn, obj, cons, bnds, options=Options{Float64}())
107107

108108
@variables z[1:p.nlp.num_var]
109109

examples/car/car.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ conT = Constraint(obs, nx, 0, nw, idx_ineq=collect(1:1))
6060
cons = [[cont for t = 1:T-1]..., conT]
6161

6262
# ## problem
63-
p = ProblemData(obj, dyn, cons, bnds, options=Options())
63+
p = solver(dyn, obj, cons, bnds)
6464

6565
# ## initialize
6666
x_interpolation = linear_interpolation(x1, xT, T)

src/DirectTrajectoryOptimization.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ include("constraints.jl")
1313
include("bounds.jl")
1414
include("general_constraint.jl")
1515
include("dynamics.jl")
16-
include("solver.jl")
1716
include("data.jl")
17+
include("solver.jl")
1818
include("moi.jl")
1919
include("utils.jl")
2020

@@ -27,9 +27,6 @@ export Bound, Bounds, Constraint, Constraints, GeneralConstraint
2727
# dynamics
2828
export Dynamics
2929

30-
# problem
31-
export ProblemData
32-
3330
# solver
3431
export Solver, Options, initialize_states!, initialize_controls!, solve!, get_trajectory
3532

src/constraints.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818

1919
Constraints{T} = Vector{Constraint{T}} where T
2020

21-
function Constraint(f::Function, nx::Int, nu::Int, nw::Int; idx_ineq=collect(1:0), eval_hess=false)
21+
function Constraint(f::Function, nx::Int, nu::Int; nw::Int=0, idx_ineq=collect(1:0), eval_hess=false)
2222
#TODO: option to load/save methods
2323
@variables x[1:nx], u[1:nu], w[1:nw]
2424
val = f(x, u, w)

src/data.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -232,48 +232,4 @@ function duals!(λ_dyn::Vector{Vector{T}}, λ_stage::Vector{Vector{T}}, λ_gen::
232232
λ_stage[t] .= @views λ[idx]
233233
end
234234
λ_gen .= λ[idx_gen]
235-
end
236-
237-
struct ProblemData{T} <: MOI.AbstractNLPEvaluator
238-
nlp::NLPData{T}
239-
s_data::SolverData
240-
end
241-
242-
function ProblemData(obj::Objective{T}, dyn::Vector{Dynamics{T}}, cons::Constraints{T}, bnds::Bounds{T};
243-
eval_hess=false,
244-
general_constraint=GeneralConstraint(),
245-
options=Options(),
246-
w=[[zeros(nw) for nw in dimensions(dyn)[3]]..., zeros(0)]) where T
247-
248-
trajopt = TrajectoryOptimizationData(obj, dyn, cons, bnds, w=w)
249-
nlp = NLPData(trajopt, general_constraint=general_constraint, eval_hess=eval_hess)
250-
s_data = SolverData(nlp, options=options)
251-
252-
ProblemData(nlp, s_data)
253-
end
254-
255-
function initialize_states!(p::ProblemData, x)
256-
for (t, xt) in enumerate(x)
257-
n = length(xt)
258-
for i = 1:n
259-
MOI.set(p.s_data.solver, MOI.VariablePrimalStart(), p.s_data.z[p.nlp.idx.x[t][i]], xt[i])
260-
end
261-
end
262-
end
263-
264-
function initialize_controls!(p::ProblemData, u)
265-
for (t, ut) in enumerate(u)
266-
m = length(ut)
267-
for j = 1:m
268-
MOI.set(p.s_data.solver, MOI.VariablePrimalStart(), p.s_data.z[p.nlp.idx.u[t][j]], ut[j])
269-
end
270-
end
271-
end
272-
273-
function get_trajectory(p::ProblemData)
274-
return p.nlp.trajopt.x, p.nlp.trajopt.u[1:end-1]
275-
end
276-
277-
function solve!(p::ProblemData)
278-
MOI.optimize!(p.s_data.solver)
279235
end

src/general_constraint.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct GeneralConstraint{T}
1515
idx_ineq::Vector{Int}
1616
end
1717

18-
function GeneralConstraint(f::Function, nz::Int, nw::Int; idx_ineq=collect(1:0), eval_hess=false)
18+
function GeneralConstraint(f::Function, nz::Int; nw::Int=0, idx_ineq=collect(1:0), eval_hess=false)
1919
#TODO: option to load/save methods
2020
@variables z[1:nz], w[1:nw]
2121
val = f(z, w)

src/objective.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ struct Cost{T}
99
hess_cache::Vector{T}
1010
end
1111

12-
function Cost(f::Function, nx::Int, nu::Int, nw::Int; eval_hess=false)
12+
function Cost(f::Function, nx::Int, nu::Int; nw::Int=0, eval_hess=false)
1313
#TODO: option to load/save methods
1414
@variables x[1:nx], u[1:nu], w[1:nw]
1515
val = f(x, u, w)

src/solver.jl

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,45 +31,46 @@
3131
# timing_statistics = :no
3232
end
3333

34-
# struct Solver{T}
35-
# p::Problem{T}
36-
# nlp_bounds::Vector{MOI.NLPBoundsPair}
37-
# block_data::MOI.NLPBlockData
38-
# solver::Ipopt.Optimizer
39-
# z::Vector{MOI.VariableIndex}
40-
# end
34+
struct Solver{T} <: MOI.AbstractNLPEvaluator
35+
nlp::NLPData{T}
36+
s_data::SolverData
37+
end
38+
39+
function solver(dyn::Vector{Dynamics{T}}, obj::Objective{T}, cons::Constraints{T}, bnds::Bounds{T};
40+
options=Options{T}(),
41+
w=[[zeros(nw) for nw in dimensions(dyn)[3]]..., zeros(0)],
42+
eval_hess=false,
43+
general_constraint=GeneralConstraint()) where T
4144

42-
# function Solver(trajopt::TrajectoryOptimizationProblem; eval_hess=false, options=Options())
43-
# p = Problem(trajopt, eval_hess=eval_hess)
44-
45-
# nlp_bounds = MOI.NLPBoundsPair.(p.con_bnds...)
46-
# block_data = MOI.NLPBlockData(nlp_bounds, p, true)
47-
48-
# # instantiate NLP solver
49-
# solver = Ipopt.Optimizer()
45+
trajopt = TrajectoryOptimizationData(obj, dyn, cons, bnds, w=w)
46+
nlp = NLPData(trajopt, general_constraint=general_constraint, eval_hess=eval_hess)
47+
s_data = SolverData(nlp, options=options)
5048

51-
# # set NLP solver options
52-
# for name in fieldnames(typeof(options))
53-
# solver.options[String(name)] = getfield(options, name)
54-
# end
55-
56-
# z = MOI.add_variables(solver, p.num_var)
57-
58-
# for i = 1:p.num_var
59-
# MOI.add_constraint(solver, z[i], MOI.LessThan(p.var_bnds[2][i]))
60-
# MOI.add_constraint(solver, z[i], MOI.GreaterThan(p.var_bnds[1][i]))
61-
# end
62-
63-
# MOI.set(solver, MOI.NLPBlock(), block_data)
64-
# MOI.set(solver, MOI.ObjectiveSense(), MOI.MIN_SENSE)
65-
66-
# return Solver(p, nlp_bounds, block_data, solver, z)
67-
# end
49+
Solver(nlp, s_data)
50+
end
6851

69-
# function initialize!(s::Solver, z)
70-
# for i = 1:s.p.num_var
71-
# MOI.set(s.solver, MOI.VariablePrimalStart(), s.z[i], z[i])
72-
# end
73-
# end
52+
function initialize_states!(p::Solver, x)
53+
for (t, xt) in enumerate(x)
54+
n = length(xt)
55+
for i = 1:n
56+
MOI.set(p.s_data.solver, MOI.VariablePrimalStart(), p.s_data.z[p.nlp.idx.x[t][i]], xt[i])
57+
end
58+
end
59+
end
7460

61+
function initialize_controls!(p::Solver, u)
62+
for (t, ut) in enumerate(u)
63+
m = length(ut)
64+
for j = 1:m
65+
MOI.set(p.s_data.solver, MOI.VariablePrimalStart(), p.s_data.z[p.nlp.idx.u[t][j]], ut[j])
66+
end
67+
end
68+
end
7569

70+
function get_trajectory(p::Solver)
71+
return p.nlp.trajopt.x, p.nlp.trajopt.u[1:end-1]
72+
end
73+
74+
function solve!(p::Solver)
75+
MOI.optimize!(p.s_data.solver)
76+
end

test/constraints.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
ct = (x, u, w) -> [-ones(nx) - x; x - ones(nx)]
1414
cT = (x, u, w) -> x
1515

16-
cont = Constraint(ct, nx, nu, nw, idx_ineq=collect(1:2nx))
17-
conT = Constraint(cT, nx, 0, 0)
16+
cont = Constraint(ct, nx, nu, nw=nw, idx_ineq=collect(1:2nx))
17+
conT = Constraint(cT, nx, 0, nw=0)
1818

1919
cons = [[cont for t = 1:T-1]..., conT]
2020
nc = DTO.num_con(cons)

0 commit comments

Comments
 (0)