Skip to content

Commit 6bc0c74

Browse files
nsicchamhauru
andauthored
Implements a simple Nutpie style adaptation (using both positions and gradients, but not changing the schedule). (#473)
* initial changes to get a working demo * fix tests, add new ones, and add documentation * fix type * fix some stray tests * delete tmp folder with demo * address review comments * reference interface refactor issue * refactor and fix test rng handling * improve docstring for NutpieVar * remove superfluous white space * fix JET tests * add entry to history, bump version * fix NutpieVar docstring * increase number of tests for mass matrix adaptation --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent 0dbc2ee commit 6bc0c74

File tree

11 files changed

+224
-55
lines changed

11 files changed

+224
-55
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# AdvancedHMC Changelog
22

3+
## 0.8.4
4+
5+
- Introduces an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)),
6+
currently to be initialized for a `metric` of type `DiagEuclideanMetric`
7+
via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`
8+
until a new interface is introduced in an upcoming breaking release to specify the method of adaptation.
9+
310
## 0.8.0
411

512
- To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/api.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ where `ϵ` is the step size of leapfrog integration.
3131
### Adaptor (`adaptor`)
3232

3333
- Adapt the mass matrix `metric` of the Hamiltonian dynamics: `mma = MassMatrixAdaptor(metric)`
34-
34+
3535
+ This is lowered to `UnitMassMatrix`, `WelfordVar` or `WelfordCov` based on the type of the mass matrix `metric`
36+
+ There is an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)),
37+
currently to be initialized for a `metric` of type `DiagEuclideanMetric`
38+
via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`
39+
until a new interface is introduced in an upcoming breaking release to specify the method of adaptation.
3640

3741
- Adapt the step size of the leapfrog integrator `integrator`: `ssa = StepSizeAdaptor(δ, integrator)`
38-
42+
3943
+ It uses Nesterov's dual averaging with `δ` as the target acceptance rate.
4044
- Combine the two above *naively*: `NaiveHMCAdaptor(mma, ssa)`
4145
- Combine the first two using Stan's windowed adaptation: `StanHMCAdaptor(mma, ssa)`
@@ -60,12 +64,12 @@ sample(
6064
Draw `n_samples` samples using the kernel `κ` under the Hamiltonian system `h`
6165

6266
- The randomness is controlled by `rng`.
63-
67+
6468
+ If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.
6569

6670
- The initial point is given by `θ`.
6771
- The adaptor is set by `adaptor`, for which the default is no adaptation.
68-
72+
6973
+ It will perform `n_adapts` steps of adaptation, for which the default is `1_000` or 10% of `n_samples`, whichever is lower.
7074
- `drop_warmup` specifies whether to drop samples.
7175
- `verbose` controls the verbosity.

src/AdvancedHMC.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export find_good_eps
7272
include("adaptation/Adaptation.jl")
7373
using .Adaptation
7474
import .Adaptation:
75-
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation
75+
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint
7676

7777
# Helpers for initializing adaptors via AHMC structs
7878

@@ -114,6 +114,7 @@ export StepSizeAdaptor,
114114
MassMatrixAdaptor,
115115
UnitMassMatrix,
116116
WelfordVar,
117+
NutpieVar,
117118
WelfordCov,
118119
NaiveHMCAdaptor,
119120
StanHMCAdaptor,

src/abstractmcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ function AbstractMCMC.step(
196196

197197
# Adapt h and spl.
198198
tstat = stat(t)
199-
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
199+
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate)
200200
tstat = merge(tstat, (is_adapt=isadapted,))
201201

202202
# Compute next transition and state.

src/adaptation/Adaptation.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ export Adaptation
44
using LinearAlgebra: LinearAlgebra
55
using Statistics: Statistics
66

7-
using ..AdvancedHMC: AbstractScalarOrVec
7+
using ..AdvancedHMC: AbstractScalarOrVec, PhasePoint
88
using DocStringExtensions
99

1010
"""
1111
$(TYPEDEF)
1212
13-
Abstract type for HMC adaptors.
13+
Abstract type for HMC adaptors.
1414
"""
1515
abstract type AbstractAdaptor end
1616
function getM⁻¹ end
@@ -21,12 +21,17 @@ function initialize! end
2121
function finalize! end
2222
export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹
2323

24+
get_position(x::PhasePoint) = x.θ
25+
get_position(x::AbstractVecOrMat{<:AbstractFloat}) = x
26+
const PositionOrPhasePoint = Union{AbstractVecOrMat{<:AbstractFloat}, PhasePoint}
27+
2428
struct NoAdaptation <: AbstractAdaptor end
2529
export NoAdaptation
2630
include("stepsize.jl")
2731
export StepSizeAdaptor, NesterovDualAveraging
32+
2833
include("massmatrix.jl")
29-
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov
34+
export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, NutpieVar, WelfordCov
3035

3136
##
3237
## Composite adaptors
@@ -47,18 +52,14 @@ getϵ(ca::NaiveHMCAdaptor) = getϵ(ca.ssa)
4752
# TODO: implement consensus adaptor
4853
function adapt!(
4954
nca::NaiveHMCAdaptor,
50-
θ::AbstractVecOrMat{<:AbstractFloat},
55+
z_or_theta::PositionOrPhasePoint,
5156
α::AbstractScalarOrVec{<:AbstractFloat},
5257
)
53-
adapt!(nca.ssa, θ, α)
54-
adapt!(nca.pc, θ, α)
55-
return nothing
56-
end
57-
function reset!(aca::NaiveHMCAdaptor)
58-
reset!(aca.ssa)
59-
reset!(aca.pc)
58+
adapt!(nca.ssa, z_or_theta, α)
59+
adapt!(nca.pc, z_or_theta, α)
6060
return nothing
6161
end
62+
6263
initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int) = nothing
6364
finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)
6465

src/adaptation/massmatrix.jl

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ finalize!(::MassMatrixAdaptor) = nothing
99

1010
function adapt!(
1111
adaptor::MassMatrixAdaptor,
12-
θ::AbstractVecOrMat{<:AbstractFloat},
13-
α::AbstractScalarOrVec{<:AbstractFloat},
12+
z_or_theta::PositionOrPhasePoint,
13+
::AbstractScalarOrVec{<:AbstractFloat},
1414
is_update::Bool=true,
1515
)
16-
resize_adaptor!(adaptor, size(θ))
17-
push!(adaptor, θ)
16+
resize_adaptor!(adaptor, size(get_position(z_or_theta)))
17+
push!(adaptor, z_or_theta)
1818
is_update && update!(adaptor)
1919
return nothing
2020
end
2121

22+
Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta))
23+
2224
## Unit mass matrix adaptor
2325

2426
struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
@@ -39,15 +41,14 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T
3941

4042
function adapt!(
4143
::UnitMassMatrix,
42-
::AbstractVecOrMat{<:AbstractFloat},
44+
::PositionOrPhasePoint,
4345
::AbstractScalarOrVec{<:AbstractFloat},
4446
is_update::Bool=true,
4547
)
4648
return nothing
4749
end
4850

4951
## Diagonal mass matrix adaptor
50-
5152
abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end
5253

5354
getM⁻¹(ve::DiagMatrixEstimator) = ve.var
@@ -70,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri
7071

7172
NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz)
7273

73-
Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s)
74+
Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s)
7475

7576
reset!(nv::NaiveVar) = resize!(nv.S, 0)
7677

@@ -135,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
135136
return nothing
136137
end
137138

138-
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T}
139+
function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat}
139140
wv.n += 1
140141
(; δ, μ, M, n) = wv
141142
n = T(n)
@@ -153,6 +154,90 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
153154
return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5))
154155
end
155156

157+
"""
158+
NutpieVar
159+
160+
Nutpie-style diagonal mass matrix estimator (using positions and gradients).
161+
162+
Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement.
163+
164+
Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
165+
166+
# Fields
167+
168+
$(FIELDS)
169+
"""
170+
mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
171+
"Online variance estimator of the posterior positions."
172+
position_estimator::WelfordVar{T,E,V}
173+
"Online variance estimator of the posterior gradients."
174+
gradient_estimator::WelfordVar{T,E,V}
175+
"The number of observations collected so far."
176+
n::Int
177+
"The minimal number of observations after which the estimate of the variances can be updated."
178+
n_min::Int
179+
"The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`."
180+
var::V
181+
function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V}
182+
return new{eltype(E),E,V}(
183+
WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
184+
WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
185+
n, n_min, var
186+
)
187+
end
188+
end
189+
190+
function Base.show(io::IO, ::NutpieVar{T}) where {T}
191+
return print(io, "NutpieVar{", T, "} adaptor")
192+
end
193+
194+
function NutpieVar{T}(
195+
sz::Union{Tuple{Int},Tuple{Int,Int}}=(2,); n_min::Int=10, var=ones(T, sz)
196+
) where {T<:AbstractFloat}
197+
return NutpieVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var)
198+
end
199+
200+
function NutpieVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...)
201+
return NutpieVar{Float64}(sz; kwargs...)
202+
end
203+
204+
function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat}
205+
if size_θ != size(nv.var)
206+
@assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
207+
resize_adaptor!(nv.position_estimator, size_θ)
208+
resize_adaptor!(nv.gradient_estimator, size_θ)
209+
nv.var = ones(T, size_θ)
210+
end
211+
end
212+
213+
function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat}
214+
length_θ = first(size_θ)
215+
if length_θ != size(nv.var, 1)
216+
@assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
217+
resize_adaptor!(nv.position_estimator, size_θ)
218+
resize_adaptor!(nv.gradient_estimator, size_θ)
219+
fill!(resize!(nv.var, length_θ), T(1))
220+
end
221+
end
222+
223+
function reset!(nv::NutpieVar)
224+
nv.n = 0
225+
reset!(nv.position_estimator)
226+
reset!(nv.gradient_estimator)
227+
end
228+
229+
Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!")
230+
231+
function Base.push!(nv::NutpieVar, z::PhasePoint)
232+
nv.n += 1
233+
push!(nv.position_estimator, z.θ)
234+
push!(nv.gradient_estimator, z.ℓπ.gradient)
235+
return nothing
236+
end
237+
238+
# Ref: https://github.com/pymc-devs/nutpie
239+
get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator))
240+
156241
## Dense mass matrix adaptor
157242

158243
abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end
@@ -175,7 +260,7 @@ end
175260

176261
NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}())
177262

178-
Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s)
263+
Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s)
179264

180265
reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0)
181266

@@ -225,7 +310,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat}
225310
return nothing
226311
end
227312

228-
function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T}
313+
function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat}
229314
wc.n += 1
230315
(; δ, μ, n, M) = wc
231316
n = T(n)

src/adaptation/stan_adaptor.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,20 @@ is_window_end(a::StanHMCAdaptor) = a.state.i in a.state.window_splits
136136

137137
function adapt!(
138138
tp::StanHMCAdaptor,
139-
θ::AbstractVecOrMat{<:AbstractFloat},
139+
z_or_theta::PositionOrPhasePoint,
140140
α::AbstractScalarOrVec{<:AbstractFloat},
141141
)
142142
tp.state.i += 1
143143

144-
adapt!(tp.ssa, θ, α)
144+
adapt!(tp.ssa, z_or_theta, α)
145145

146-
resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary.
146+
resize_adaptor!(tp.pc, size(get_position(z_or_theta))) # Resize pre-conditioner if necessary.
147147

148148
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
149149
if is_in_window(tp)
150150
# We accumlate stats from θ online and only trigger the update of M⁻¹ in the end of window.
151151
is_update_M⁻¹ = is_window_end(tp)
152-
adapt!(tp.pc, θ, α, is_update_M⁻¹)
152+
adapt!(tp.pc, z_or_theta, α, is_update_M⁻¹)
153153
end
154154

155155
if is_window_end(tp)

src/adaptation/stepsize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ end
174174
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
175175
# Note: This function is not merged with `adapt!` to empahsize the fact that
176176
# step size adaptation is not dependent on `θ`.
177-
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
177+
# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
178178
function adapt_stepsize!(
179179
da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
180180
) where {T<:AbstractFloat}
@@ -211,7 +211,7 @@ end
211211

212212
function adapt!(
213213
da::NesterovDualAveraging,
214-
θ::AbstractVecOrMat{<:AbstractFloat},
214+
::PositionOrPhasePoint,
215215
α::AbstractScalarOrVec{<:AbstractFloat},
216216
)
217217
adapt_stepsize!(da, α)

0 commit comments

Comments
 (0)