Skip to content
2 changes: 1 addition & 1 deletion CondaPkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
pandas = ""
matplotlib = ""
xarray = ""
arviz = ">=0.15.0,<=0.18"
arviz = ">=0.15.0,<=0.22"
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZPythonPlots"
uuid = "4a6e88f0-2c8e-11ee-0601-e94153f0eada"
authors = ["Seth Axen <[email protected]>"]
version = "0.1.12"
version = "0.1.13"

[deps]
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
Expand Down
4 changes: 4 additions & 0 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ function PythonCall.Py(d::PSISLOOResult)
psis_result = d.psis_result
ds = convert_to_dataset((loo_i=pointwise.elpd, pareto_shape=pointwise.pareto_shape))
pyds = PythonCall.Py(ds)
n_samples = d.psis_result.ndraws * d.psis_result.nchains
good_k = min(1 - inv(log10(n_samples)), 0.7)

entries = (
elpd_loo=estimates.elpd,
se=estimates.se_elpd,
Expand All @@ -14,6 +17,7 @@ function PythonCall.Py(d::PSISLOOResult)
loo_i=pyds.loo_i,
pareto_k=pyds.pareto_shape,
scale="log",
good_k=good_k,
)
data = pylist(values(entries))
index = pylist(map(pystr, keys(entries)))
Expand Down
4 changes: 2 additions & 2 deletions test/test_conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ using Test
loo_result = loo(idata; reff=1)
loo_py_result = ArviZPythonPlots.arviz.loo(idata; pointwise=true, reff=1)
py_loo_result = Py(loo_result)
@test all(
pyconvert(Array{String}, py_loo_result.keys()) ==
@test issubset(
pyconvert(Array{String}, loo_py_result.keys()),
pyconvert(Array{String}, py_loo_result.keys()),
)
@test pyconvert(Float64, py_loo_result.elpd_loo) ≈
pyconvert(Float64, loo_py_result.elpd_loo) rtol = 1e-3
Expand Down
Loading