2 changes: 1 addition & 1 deletion Project.toml
@@ -41,7 +41,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMCMCChainsExt = ["MCMCChains", "Statistics"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMooncakeExt = ["Mooncake"]
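
A quick way to sanity-check the new trigger set (a sketch; assumes `Statistics` is also declared under `[weakdeps]`):

```julia
using DynamicPPL
using MCMCChains, Statistics  # loading both triggers DynamicPPLMCMCChainsExt

# `Base.get_extension` (Julia >= 1.9) returns the extension module once it is loaded.
@assert Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt) !== nothing
```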

39 changes: 39 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -2,6 +2,7 @@ module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
using MCMCChains: MCMCChains
using Statistics: mean

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names

@@ -140,6 +141,44 @@ function AbstractMCMC.to_samples(
end
end

function AbstractMCMC.bundle_samples(
ts::Vector{<:DynamicPPL.ParamsWithStats},
model::DynamicPPL.Model,
spl::AbstractMCMC.AbstractSampler,
state,
chain_type::Type{MCMCChains.Chains};
save_state=false,
stats=missing,
sort_chain=false,
discard_initial=0,
thinning=1,
kwargs...,
)
# Construct the 'bare' chain first
bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1))

# Add additional MCMC-specific info
info = bare_chain.info
if save_state
info = merge(info, (model=model, sampler=spl, samplerstate=state))
end
if !ismissing(stats)
info = merge(info, (start_time=stats.start, stop_time=stats.stop))
end

# Reconstruct the chain with the extra information
# Yeah, this is quite ugly. Blame MCMCChains.
chain = MCMCChains.Chains(
bare_chain.value.data,
names(bare_chain),
bare_chain.name_map;
info=info,
start=discard_initial + 1,
thin=thinning,
)
return sort_chain ? sort(chain) : chain
end
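
For orientation, a sketch of how these keywords reach `bundle_samples` through `AbstractMCMC.sample` (`demo_model` and `MySampler` are hypothetical placeholders):

```julia
using AbstractMCMC, MCMCChains

# Keyword arguments to `sample` are forwarded on to `bundle_samples`.
chain = sample(
    demo_model, MySampler(), 1_000;
    chain_type=MCMCChains.Chains,
    save_state=true,      # stores model, sampler, and final state in chain.info
    discard_initial=250,  # recorded in the chain via `start=251`
    thinning=2,           # recorded in the chain via `thin=2`
)
chain.info.samplerstate   # present because save_state=true
```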

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -202,14 +202,14 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("experimental.jl")
include("chains.jl")
include("bijector.jl")

include("debug_utils.jl")
using .DebugUtils
include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")

if isdefined(Base.Experimental, :register_error_hint)
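
(Moving `include("experimental.jl")` above `include("chains.jl")` is presumably required because `chains.jl` now defines a method dispatching on `DynamicPPL.Experimental.FastLDF`, so the `Experimental` submodule must exist by the time `chains.jl` is included.)
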
57 changes: 57 additions & 0 deletions src/chains.jl
@@ -133,3 +133,60 @@ function ParamsWithStats(
end
return ParamsWithStats(params, stats)
end

"""
ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)

Generate a `ParamsWithStats` by re-evaluating the model stored in `ldf` at the provided
`param_vector`.

This method is intended to replace the old approach of obtaining parameters and statistics
via `unflatten` plus re-evaluation. It is faster for two reasons:

1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
MCMC iterations).
2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
"""
function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)
strategy = InitFromParams(
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
nothing,
)
accs = if include_log_probs
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
)
else
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
end
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
ldf.model, strategy, AccumulatorTuple(accs)
)
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
if include_log_probs
stats = merge(
stats,
(
logprior=DynamicPPL.getlogprior(vi),
loglikelihood=DynamicPPL.getloglikelihood(vi),
lp=DynamicPPL.getlogjoint(vi),
),
)
end
return ParamsWithStats(params, stats)
end
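
A usage sketch mirroring the new test below (the one-liner model is a stand-in; `getlogjoint` as used there):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
m = demo()
vi = VarInfo(m)

fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
ps = ParamsWithStats(vi[:], fldf)

ps.params[@varname(x)]  # parameter value as seen inside the model
ps.stats.lp             # log joint; present because include_log_probs=true
```
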
81 changes: 58 additions & 23 deletions src/fasteval.jl
@@ -3,6 +3,7 @@ using DynamicPPL:
AccumulatorTuple,
InitContext,
InitFromParams,
AbstractInitStrategy,
LogJacobianAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
@@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems
import DifferentiationInterface as DI
using Random: Random

"""
DynamicPPL.Experimental.fast_evaluate!!(
[rng::Random.AbstractRNG,]
model::Model,
strategy::AbstractInitStrategy,
    accs::AccumulatorTuple,
)

Evaluate a model using parameters obtained via `strategy`, computing only the results in
the provided accumulators.

It is assumed that the accumulators passed in have been initialised to appropriate values,
as this function will not reset them. The default constructors for each accumulator will do
this for you correctly.

Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
in the function name.
"""
@inline function fast_evaluate!!(
# Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
# to extra allocations (even for trivial models) and much slower runtime.
rng::Random.AbstractRNG,
model::Model,
strategy::AbstractInitStrategy,
accs::AccumulatorTuple,
)
ctx = InitContext(rng, strategy)
model = DynamicPPL.setleafcontext(model, ctx)
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
# here.
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
param_eltype = DynamicPPL.get_param_eltype(strategy)
accs = map(accs) do acc
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
end
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
else
OnlyAccsVarInfo(accs)
end
return DynamicPPL._evaluate!!(model, vi)
end
@inline function fast_evaluate!!(
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
)
# This `@inline` is also mandatory for performance
return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
end
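
A direct-call sketch, assuming a model `m` and `InitFromPrior` from DynamicPPL's initialisation strategies:

```julia
using Random

accs = DynamicPPL.AccumulatorTuple((
    DynamicPPL.LogPriorAccumulator(),
    DynamicPPL.LogLikelihoodAccumulator(),
))
# Draw parameters from the prior and accumulate only these two log-probabilities.
retval, vi = DynamicPPL.Experimental.fast_evaluate!!(
    Random.Xoshiro(468), m, DynamicPPL.InitFromPrior(), accs
)
(DynamicPPL.getlogprior(vi), DynamicPPL.getloglikelihood(vi))
```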

"""
FastLDF(
model::Model,
@@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
-    ctx = InitContext(
-        Random.default_rng(),
-        InitFromParams(
-            VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
-        ),
+    strategy = InitFromParams(
+        VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
    )
-    model = DynamicPPL.setleafcontext(f.model, ctx)
    accs = fast_ldf_accs(f.getlogdensity)
-    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
-    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
-    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
-    # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        accs = map(
-            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
-            accs,
-        )
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    _, vi = DynamicPPL._evaluate!!(model, vi)
+    _, vi = fast_evaluate!!(f.model, strategy, accs)
return f.getlogdensity(vi)
end
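
This functor presumably backs the `LogDensityProblems` interface for `FastLDF`; roughly:

```julia
using LogDensityProblems

vi = VarInfo(m)
fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
# Evaluation goes through a FastLogDensityAt-style callable under the hood.
LogDensityProblems.logdensity(fldf, vi[:])
```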

28 changes: 27 additions & 1 deletion test/chains.jl
@@ -4,7 +4,7 @@ using DynamicPPL
using Distributions
using Test

@testset "ParamsWithStats" begin
@testset "ParamsWithStats from VarInfo" begin
@model function f(z)
x ~ Normal()
y := x + 1
@@ -66,4 +66,30 @@ using Test
end
end

@testset "ParamsWithStats from FastLDF" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
unlinked_vi = VarInfo(m)
@testset "$islinked" for islinked in (false, true)
vi = if islinked
DynamicPPL.link!!(unlinked_vi, m)
else
unlinked_vi
end
params = [x for x in vi[:]]

# Get the ParamsWithStats using FastLDF
fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
ps = ParamsWithStats(params, fldf)

# Check that length of parameters is as expected
@test length(ps.params) == length(keys(vi))

# Iterate over all variables to check that their values match
for vn in keys(vi)
@test ps.params[vn] == vi[vn]
end
end
end
end

end # module