Skip to content

Commit 3efc60a

Browse files
authored
predict for Disitributions (#295)
* `predict` for Disitributions * add tests * bump version * small update to test
1 parent e104960 commit 3efc60a

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Soss"
22
uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
33
author = ["Chad Scherrer <[email protected]>"]
4-
version = "0.20.2"
4+
version = "0.20.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/transforms/predict.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using TupleVectors
44
predict(m::AbstractModel, args...) = predict(Random.GLOBAL_RNG, m, args...)
55
predict(d::AbstractMeasure, x) = x
66

7+
predict(d::Dists.Distribution, x) = x
78

89
@inline function predict(rng::AbstractRNG, m::AbstractModel, nt::NamedTuple{N}) where {N}
910
pred = predictive(Model(m), N...)

test/runtests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,5 +132,22 @@ include("examples-list.jl")
132132

133133
@test transform(xform(post), randn(6)) isa NamedTuple
134134

135+
@testset "logdensity" begin
136+
dat = randn(100)
137+
m = Soss.@model n begin
138+
μ ~ Dists.Normal()
139+
σ ~ Dists.Exponential()
140+
data ~ Dists.Normal(μ, σ) |> iid(n)
141+
return (; data)
142+
end
143+
mod = m( (; n = length(dat) ) )
144+
post = mod | (data = dat,)
145+
146+
@test logdensity( mod( (μ = 1., σ = 2., data = dat) ) ) == logdensity( post( (μ = 1., σ = 2.) ) )
147+
end
148+
149+
135150
end
151+
152+
136153
end

0 commit comments

Comments
 (0)