|
1 | | -# TODO: Move to Bijectors.jl if we find further use for this. |
2 | | -""" |
3 | | - wrap_in_vec_reshape(f, in_size) |
4 | | -
|
5 | | -Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces |
6 | | -a vector of length `prod(Bijectors.output(f, in_size))`. |
7 | | -""" |
8 | | -function wrap_in_vec_reshape(f, in_size) |
9 | | - vec_in_length = prod(in_size) |
10 | | - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) |
11 | | - out_size = Bijectors.output_size(f, in_size) |
12 | | - vec_out_length = prod(out_size) |
13 | | - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) |
14 | | - return reshape_outer ∘ f ∘ reshape_inner |
15 | | -end |
16 | | - |
17 | | -""" |
18 | | - bijector(model::Model[, sym2ranges = Val(false)]) |
19 | | -
|
20 | | -Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` |
21 | | -denoting the dimensionality of the latent variables. |
22 | | -""" |
23 | | -function Bijectors.bijector( |
24 | | - model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) |
25 | | -) where {sym2ranges} |
26 | | - num_params = sum([ |
27 | | - size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) |
28 | | - ]) |
29 | | - |
30 | | - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) |
31 | | - |
32 | | - num_ranges = sum([ |
33 | | - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) |
34 | | - ]) |
35 | | - ranges = Vector{UnitRange{Int}}(undef, num_ranges) |
36 | | - idx = 0 |
37 | | - range_idx = 1 |
38 | | - |
39 | | - # ranges might be discontinuous => values are vectors of ranges rather than just ranges |
40 | | - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() |
41 | | - for sym in keys(varinfo.metadata) |
42 | | - sym_lookup[sym] = Vector{UnitRange{Int}}() |
43 | | - for r in varinfo.metadata[sym].ranges |
44 | | - ranges[range_idx] = idx .+ r |
45 | | - push!(sym_lookup[sym], ranges[range_idx]) |
46 | | - range_idx += 1 |
47 | | - end |
48 | | - |
49 | | - idx += varinfo.metadata[sym].ranges[end][end] |
50 | | - end |
51 | | - |
52 | | - bs = map(tuple(dists...)) do d |
53 | | - b = Bijectors.bijector(d) |
54 | | - if d isa Distributions.UnivariateDistribution |
55 | | - b |
56 | | - else |
57 | | - wrap_in_vec_reshape(b, size(d)) |
58 | | - end |
59 | | - end |
60 | | - |
61 | | - if sym2ranges |
62 | | - return ( |
63 | | - Bijectors.Stacked(bs, ranges), |
64 | | - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), |
65 | | - ) |
66 | | - else |
67 | | - return Bijectors.Stacked(bs, ranges) |
68 | | - end |
69 | | -end |
70 | | - |
71 | 1 | """ |
72 | 2 | meanfield([rng, ]model::Model) |
73 | 3 |
|
|
0 commit comments