diff --git a/CondaPkg.toml b/CondaPkg.toml index d485f00..3feb6b3 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -2,4 +2,4 @@ pandas = "" matplotlib = "" xarray = "" -arviz = ">=0.15.0,<=0.18" +arviz = ">=0.15.0,<=0.22" diff --git a/Project.toml b/Project.toml index 530f9a2..bcfcd16 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArviZPythonPlots" uuid = "4a6e88f0-2c8e-11ee-0601-e94153f0eada" authors = ["Seth Axen "] -version = "0.1.12" +version = "0.1.13" [deps] CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" diff --git a/src/conversions.jl b/src/conversions.jl index 185845f..66febda 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -4,6 +4,9 @@ function PythonCall.Py(d::PSISLOOResult) psis_result = d.psis_result ds = convert_to_dataset((loo_i=pointwise.elpd, pareto_shape=pointwise.pareto_shape)) pyds = PythonCall.Py(ds) + n_samples = d.psis_result.ndraws * d.psis_result.nchains + good_k = min(1 - inv(log10(n_samples)), 0.7) + entries = ( elpd_loo=estimates.elpd, se=estimates.se_elpd, @@ -14,6 +17,7 @@ function PythonCall.Py(d::PSISLOOResult) loo_i=pyds.loo_i, pareto_k=pyds.pareto_shape, scale="log", + good_k=good_k, ) data = pylist(values(entries)) index = pylist(map(pystr, keys(entries))) diff --git a/test/test_conversions.jl b/test/test_conversions.jl index 9a7a308..4a08af0 100644 --- a/test/test_conversions.jl +++ b/test/test_conversions.jl @@ -10,9 +10,9 @@ using Test loo_result = loo(idata; reff=1) loo_py_result = ArviZPythonPlots.arviz.loo(idata; pointwise=true, reff=1) py_loo_result = Py(loo_result) - @test all( - pyconvert(Array{String}, py_loo_result.keys()) == + @test issubset( pyconvert(Array{String}, loo_py_result.keys()), + pyconvert(Array{String}, py_loo_result.keys()), ) @test pyconvert(Float64, py_loo_result.elpd_loo) ≈ pyconvert(Float64, loo_py_result.elpd_loo) rtol = 1e-3