Skip to content

Commit 24e6edd

Browse files
committed
Enable ECE with vector of scalars (#109)
bors r+ Co-authored-by: David Widmann <[email protected]>
1 parent 9be8eef commit 24e6edd

File tree

9 files changed

+176
-29
lines changed

9 files changed

+176
-29
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CalibrationErrors"
22
uuid = "33913031-fe46-5864-950f-100836f47845"
33
authors = ["David Widmann <[email protected]>"]
4-
version = "0.5.20"
4+
version = "0.5.21"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/CalibrationErrors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ using Reexport
55
using DataStructures
66
@reexport using Distances
77
@reexport using KernelFunctions
8-
using StatsBase
9-
using UnPack
8+
using StatsBase: StatsBase
9+
using UnPack: @unpack
1010

1111
using LinearAlgebra
1212
using Statistics

src/binning/generic.jl

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
abstract type AbstractBinningAlgorithm end
22

3-
mutable struct Bin{T<:Real}
3+
mutable struct Bin{T}
44
"""Number of samples."""
55
nsamples::Int
66
"""Mean of predictions."""
7-
mean_predictions::Vector{T}
7+
mean_predictions::T
88
"""Proportions of targets."""
9-
proportions_targets::Vector{T}
9+
proportions_targets::T
1010

1111
function Bin{T}(
12-
nsamples::Int, mean_predictions::Vector{T}, proportions_targets::Vector{T}
13-
) where {T}
12+
nsamples::Int, mean_predictions::T, proportions_targets::T
13+
) where {T<:Real}
14+
nsamples ≥ 0 || throw(ArgumentError("the number of samples must be non-negative"))
15+
return new{T}(nsamples, mean_predictions, proportions_targets)
16+
end
17+
18+
function Bin{T}(
19+
nsamples::Int, mean_predictions::T, proportions_targets::T
20+
) where {T<:AbstractVector{<:Real}}
1421
nsamples ≥ 0 || throw(ArgumentError("the number of samples must be non-negative"))
1522
nclasses = length(mean_predictions)
1623
nclasses > 1 || throw(ArgumentError("the number of classes must be greater than 1"))
@@ -24,9 +31,7 @@ mutable struct Bin{T<:Real}
2431
end
2532
end
2633

27-
function Bin(
28-
nsamples::Int, mean_predictions::Vector{T}, proportions_targets::Vector{T}
29-
) where {T<:Real}
34+
function Bin(nsamples::Int, mean_predictions::T, proportions_targets::T) where {T}
3035
return Bin{T}(nsamples, mean_predictions, proportions_targets)
3136
end
3237

@@ -35,6 +40,15 @@ end
3540
3641
Create bin of `predictions` and corresponding `targets`.
3742
"""
43+
function Bin(predictions::AbstractVector{<:Real}, targets::AbstractVector{Bool})
44+
# compute mean of predictions
45+
mean_predictions = mean(predictions)
46+
47+
# compute proportion of targets
48+
proportions_targets = mean(targets)
49+
50+
return Bin(length(predictions), mean_predictions, proportions_targets)
51+
end
3852
function Bin(
3953
predictions::AbstractVector{<:AbstractVector{<:Real}},
4054
targets::AbstractVector{<:Integer},
@@ -44,16 +58,26 @@ function Bin(
4458

4559
# compute proportion of targets
4660
nclasses = length(predictions[1])
47-
proportions_targets = proportions(targets, nclasses)
61+
proportions_targets = StatsBase.proportions(targets, nclasses)
4862

4963
return Bin(length(predictions), mean_predictions, proportions_targets)
5064
end
5165

5266
"""
5367
Bin(prediction, target)
5468
55-
Create bin of a signle `prediction` and corresponding `target`.
69+
Create bin of a single `prediction` and corresponding `target`.
5670
"""
71+
function Bin(prediction::Real, target::Bool)
72+
# compute mean of predictions
73+
mean_predictions = prediction / 1
74+
75+
# compute proportion of targets
76+
proportions_targets = target / 1
77+
78+
return Bin(1, mean_predictions, proportions_targets)
79+
end
80+
5781
function Bin(prediction::AbstractVector{<:Real}, target::Integer)
5882
# compute mean of predictions
5983
mean_predictions = prediction ./ 1
@@ -73,6 +97,22 @@ end
7397
Update running statistics of the `bin` by integrating one additional pair of `prediction`s
7498
and `target`.
7599
"""
100+
function adddata!(bin::Bin, prediction::Real, target::Bool)
101+
@unpack mean_predictions, proportions_targets = bin
102+
103+
# update number of samples
104+
nsamples = (bin.nsamples += 1)
105+
106+
# update mean of predictions
107+
mean_predictions += (prediction - mean_predictions) / nsamples
108+
bin.mean_predictions = mean_predictions
109+
110+
# update proportions of targets
111+
proportions_targets += (target - proportions_targets) / nsamples
112+
bin.proportions_targets = proportions_targets
113+
114+
return nothing
115+
end
76116
function adddata!(bin::Bin, prediction::AbstractVector{<:Real}, target::Integer)
77117
@unpack mean_predictions, proportions_targets = bin
78118

src/binning/medianvariance.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ MedianVarianceBinning(minsize::Int=10) = MedianVarianceBinning(minsize, typemax(
2525

2626
function perform(
2727
alg::MedianVarianceBinning,
28-
predictions::AbstractVector{<:AbstractVector{T}},
28+
predictions::AbstractVector{<:AbstractVector{<:Real}},
2929
targets::AbstractVector{<:Integer},
30-
) where {T<:Real}
30+
)
3131
@unpack minsize, maxbins = alg
3232

3333
# check if binning is not possible
@@ -50,8 +50,7 @@ function perform(
5050
(idxs_predictions, argmax_var_predictions) => max_var_predictions,
5151
Base.Order.Reverse,
5252
)
53-
S = typeof(zero(T) / 1)
54-
bins = Vector{Bin{S}}(undef, 0)
53+
bins = Vector{typeof(Bin(predictions, targets))}(undef, 0)
5554

5655
nbins = 1
5756
while nbins < maxbins && !isempty(queue)

src/binning/uniform.jl

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,41 @@ end
1616

1717
function perform(
1818
binning::UniformBinning,
19-
predictions::AbstractVector{<:AbstractVector{T}},
19+
predictions::AbstractVector{<:Real},
20+
targets::AbstractVector{Bool},
21+
)
22+
@unpack nbins = binning
23+
24+
# create dictionary of bins
25+
T = eltype(float(zero(eltype(predictions))))
26+
bins = Dict{Int,Bin{T}}()
27+
28+
# reserve some memory (very rough guess)
29+
nsamples = length(predictions)
30+
sizehint!(bins, min(nbins, nsamples))
31+
32+
# for all other samples
33+
@inbounds for (prediction, target) in zip(predictions, targets)
34+
# compute index of bin
35+
index = binindex(prediction, nbins)
36+
37+
# create new bin or update existing one
38+
bin = get(bins, index, nothing)
39+
if bin === nothing
40+
bins[index] = Bin(prediction, target)
41+
else
42+
adddata!(bin, prediction, target)
43+
end
44+
end
45+
46+
return values(bins)
47+
end
48+
49+
function perform(
50+
binning::UniformBinning,
51+
predictions::AbstractVector{<:AbstractVector{<:Real}},
2052
targets::AbstractVector{<:Integer},
21-
) where {T<:Real}
53+
)
2254
return _perform(binning, predictions, targets, Val(length(predictions[1])))
2355
end
2456

@@ -36,8 +68,9 @@ function _perform(
3668

3769
# reserve some memory (very rough guess)
3870
nsamples = length(predictions)
39-
sizehint!(bins, nsamples)
40-
sizehint!(binindices, nsamples)
71+
guess = min(nbins, nsamples)
72+
sizehint!(bins, guess)
73+
sizehint!(binindices, guess)
4174

4275
# for all other samples
4376
@inbounds for i in 2:nsamples

src/ece.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,11 @@ In particular, distance measures of the package
3131
ECE(binning::AbstractBinningAlgorithm) = ECE(binning, TotalVariation())
3232

3333
# estimate ECE
34-
function (ece::ECE)(
35-
predictions::AbstractVector{<:AbstractVector{<:Real}},
36-
targets::AbstractVector{<:Integer},
37-
)
34+
function (ece::ECE)(predictions::AbstractVector, targets::AbstractVector)
3835
@unpack binning, distance = ece
3936

4037
# check number of samples
41-
nsamples = check_nsamples(predictions, targets)
38+
check_nsamples(predictions, targets)
4239

4340
# bin predictions and labels
4441
bins = perform(binning, predictions, targets)
@@ -50,17 +47,20 @@ function (ece::ECE)(
5047

5148
# evaluate the distance in the first bin
5249
@inbounds begin
53-
bin = bins[1]
50+
bin, state = iterate(bins) # there is always at least one bin
5451
x = distance(bin.mean_predictions, bin.proportions_targets)
5552

5653
# initialize the estimate
5754
estimate = x / 1
5855

5956
# for all other bins
6057
n = bin.nsamples
61-
for i in 2:nbins
58+
while true
59+
bin_state = iterate(bins, state)
60+
bin_state === nothing && break
61+
6262
# evaluate the distance
63-
bin = bins[i]
63+
bin, state = bin_state
6464
x = distance(bin.mean_predictions, bin.proportions_targets)
6565

6666
# update the estimate

test/binning/generic.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,28 @@
11
@testset "generic.jl" begin
2+
@testset "Simple example" begin
3+
# sample predictions and outcomes
4+
nsamples = 1_000
5+
predictions = rand(nsamples)
6+
outcomes = rand(Bool, nsamples)
7+
8+
# create bin with all predictions and outcomes
9+
bin = CalibrationErrors.Bin(predictions, outcomes)
10+
11+
# check statistics
12+
@test bin.nsamples == nsamples
13+
@test bin.mean_predictions ≈ mean(predictions)
14+
@test bin.proportions_targets == mean(outcomes)
15+
16+
# compare with adding data
17+
bin2 = CalibrationErrors.Bin(predictions[1], outcomes[1])
18+
for i in 2:nsamples
19+
CalibrationErrors.adddata!(bin2, predictions[i], outcomes[i])
20+
end
21+
@test bin2.nsamples == bin.nsamples
22+
@test bin2.mean_predictions ≈ bin.mean_predictions
23+
@test bin2.proportions_targets ≈ bin.proportions_targets
24+
end
25+
226
@testset "Simple example ($nclasses classes)" for nclasses in (2, 10, 100)
327
# sample predictions and targets
428
nsamples = 1_000

test/binning/uniform.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,36 @@
2222
@test_throws ArgumentError CalibrationErrors.binindex([1.5, 0.5], 10, Val(2))
2323
end
2424

25+
@testset "Basic tests" begin
26+
# sample predictions and targets
27+
nsamples = 1_000
28+
predictions = rand(nsamples)
29+
targets = rand(Bool, nsamples)
30+
31+
for nbins in (1, 10, 100, 500, 1_000)
32+
# bin data in bins of uniform width
33+
bins = @inferred(
34+
CalibrationErrors.perform(UniformBinning(nbins), predictions, targets)
35+
)
36+
37+
# check all bins
38+
for bin in bins
39+
# compute index of bin from average prediction
40+
idx = CalibrationErrors.binindex(bin.mean_predictions, nbins)
41+
42+
# compute indices of all predictions in the same bin
43+
idxs = filter(
44+
i -> idx == CalibrationErrors.binindex(predictions[i], nbins),
45+
1:nsamples,
46+
)
47+
48+
@test bin.nsamples == length(idxs)
49+
@test bin.mean_predictions ≈ mean(predictions[idxs])
50+
@test bin.proportions_targets ≈ mean(targets[idxs])
51+
end
52+
end
53+
end
54+
2555
@testset "Basic tests ($nclasses classes)" for nclasses in (2, 10, 100)
2656
# sample predictions and targets
2757
nsamples = 1_000

test/ece.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@testset "ece.jl" begin
22
@testset "Trivial tests" begin
33
ece = ECE(UniformBinning(10))
4+
5+
# categorical distributions
46
for predictions in ([[0, 1], [1, 0]], ColVecs([0 1; 1 0]), RowVecs([0 1; 1 0]))
57
@test iszero(@inferred(ece(predictions, [2, 1])))
68
end
@@ -11,12 +13,19 @@
1113
)
1214
@test iszero(@inferred(ece(predictions, [2, 2, 1, 1])))
1315
end
16+
17+
# probabilities
18+
for predictions in ([0, 1], [0.0, 1.0])
19+
@test iszero(@inferred(ece(predictions, [false, true])))
20+
end
21+
@test iszero(@inferred(ece([0, 0.5, 0.5, 1], [false, false, true, true])))
1422
end
1523

1624
@testset "Uniform binning: Basic properties" begin
1725
ece = ECE(UniformBinning(10))
1826
estimates = Vector{Float64}(undef, 1_000)
1927

28+
# categorical distributions
2029
for nclasses in (2, 10, 100)
2130
dist = Dirichlet(nclasses, 1.0)
2231

@@ -32,6 +41,18 @@
3241

3342
@test all(x -> zero(x) < x < one(x), estimates)
3443
end
44+
45+
# probabilities
46+
predictions = Vector{Float64}(undef, 20)
47+
targets = Vector{Bool}(undef, 20)
48+
for i in 1:length(estimates)
49+
rand!(predictions)
50+
map!(targets, predictions) do p
51+
rand() < p
52+
end
53+
estimates[i] = ece(predictions, targets)
54+
end
55+
@test all(x -> zero(x) < x < one(x), estimates)
3556
end
3657

3758
@testset "Median variance binning: Basic properties" begin

0 commit comments

Comments
 (0)