Skip to content

Commit a140a02

Browse files
committed
First-class support of SKCE for probability vectors and boolean targets (#85)
1 parent 95ea24c commit a140a02

File tree

7 files changed

+191
-54
lines changed

7 files changed

+191
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CalibrationErrors"
22
uuid = "33913031-fe46-5864-950f-100836f47845"
33
authors = ["David Widmann <[email protected]>"]
4-
version = "0.5.16"
4+
version = "0.5.17"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/skce/generic.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ function unsafe_skce_eval(
8888
return result
8989
end
9090

91+
# for binary classification with probabilities (corresponding to parameters of Bernoulli
92+
# distributions) and boolean targets the expression simplifies to
93+
# ```math
94+
# k((p, y), (p̃, ỹ)) = (y(1-p) + (1-y)p)(ỹ(1-p̃) + (1-ỹ)p̃)(k((p, y), (p̃, ỹ)) - k((p, 1-y), (p̃, ỹ)) - k((p, y), (p̃, 1-ỹ)) + k((p, 1-y), (p̃, 1-ỹ)))
95+
# ```
96+
function unsafe_skce_eval(kernel::Kernel, p::Real, y::Bool, p̃::Real, ỹ::Bool)
97+
noty = !y
98+
notỹ = !
99+
z =
100+
kernel((p, y), (p̃, ỹ)) - kernel((p, noty), (p̃, ỹ)) -
101+
kernel((p, y), (p̃, notỹ)) + kernel((p, noty), (p̃, notỹ))
102+
return (y ? 1 - p : p) * (ỹ ? 1 -: p̃) * z
103+
end
104+
91105
# evaluation for tensor product kernels
92106
function unsafe_skce_eval(kernel::KernelTensorProduct, p, y, p̃, ỹ)
93107
κpredictions, κtargets = kernel.kernels
@@ -105,6 +119,10 @@ function unsafe_skce_eval(
105119
κpredictions, κtargets = kernel.kernels
106120
return κpredictions(p, p̃) * unsafe_skce_eval_targets(κtargets, p, y, p̃, ỹ)
107121
end
122+
function unsafe_skce_eval(kernel::KernelTensorProduct, p::Real, y::Bool, p̃::Real, ỹ::Bool)
123+
κpredictions, κtargets = kernel.kernels
124+
return κpredictions(p, p̃) * unsafe_skce_eval_targets(κtargets, p, y, p̃, ỹ)
125+
end
108126

109127
function unsafe_skce_eval_targets(
110128
κtargets::Kernel,
@@ -258,3 +276,13 @@ function unsafe_skce_eval_targets(
258276
@inbounds res = (y == ỹ) - p[ỹ] - p̃[y] + dot(p, p̃)
259277
return res
260278
end
279+
280+
function unsafe_skce_eval_targets(κtargets::Kernel, p::Real, y::Bool, p̃::Real, ỹ::Bool)
281+
noty = !y
282+
notỹ = !
283+
z = κtargets(y, ỹ) - κtargets(noty, ỹ) - κtargets(y, notỹ) + κtargets(noty, notỹ)
284+
return (y ? 1 - p : p) * (ỹ ? 1 -: p̃) * z
285+
end
286+
function unsafe_skce_eval_targets(::WhiteKernel, p::Real, y::Bool, p̃::Real, ỹ::Bool)
287+
return 2 * (y - p) * (ỹ - p̃)
288+
end

test/kernels.jl

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,15 @@
11
@testset "kernels.jl" begin
2-
# alternative implementation of white kernel
3-
struct WhiteKernel2 <: Kernel end
4-
(::WhiteKernel2)(x, y) = x == y
2+
kernel = TVExponentialKernel()
53

6-
# alternative implementation TensorProductKernel
7-
struct TensorProduct2{K1<:Kernel,K2<:Kernel} <: Kernel
8-
kernel1::K1
9-
kernel2::K2
10-
end
11-
function (kernel::TensorProduct2)((x1, x2), (y1, y2))
12-
return kernel.kernel1(x1, y1) * kernel.kernel2(x2, y2)
13-
end
4+
# traits
5+
@test KernelFunctions.metric(kernel) === TotalVariation()
146

15-
@testset "TVExponentialKernel" begin
16-
kernel = TVExponentialKernel()
7+
# simple evaluation
8+
x, y = rand(10), rand(10)
9+
@test kernel(x, y) == exp(-totalvariation(x, y))
1710

18-
# traits
19-
@test KernelFunctions.metric(kernel) === TotalVariation()
20-
21-
# simple evaluation
22-
x, y = rand(10), rand(10)
23-
@test kernel(x, y) == exp(-totalvariation(x, y))
24-
25-
# transformations
26-
@test (kernel ScaleTransform(0.1))(x, y) == exp(-0.1 * totalvariation(x, y))
27-
ard = rand(10)
28-
@test (kernel ARDTransform(ard))(x, y) == exp(-totalvariation(ard .* x, ard .* y))
29-
end
30-
31-
@testset "unsafe_skce_eval" begin
32-
kernel = SqExponentialKernel()
33-
kernel1 = kernel WhiteKernel()
34-
kernel2 = kernel WhiteKernel2()
35-
kernel3 = TensorProduct2(kernel, WhiteKernel())
36-
37-
x1, x2 = rand(10), rand(1:10)
38-
39-
@test CalibrationErrors.unsafe_skce_eval(kernel1, x1, x2, x1, x2)
40-
CalibrationErrors.unsafe_skce_eval(kernel2, x1, x2, x1, x2)
41-
@test CalibrationErrors.unsafe_skce_eval(kernel1, x1, x2, x1, x2)
42-
CalibrationErrors.unsafe_skce_eval(kernel3, x1, x2, x1, x2)
43-
44-
y1, y2 = rand(10), rand(1:10)
45-
46-
@test CalibrationErrors.unsafe_skce_eval(kernel1, x1, x2, y1, y2)
47-
CalibrationErrors.unsafe_skce_eval(kernel2, x1, x2, y1, y2)
48-
@test CalibrationErrors.unsafe_skce_eval(kernel1, x1, x2, y1, y2)
49-
CalibrationErrors.unsafe_skce_eval(kernel3, x1, x2, y1, y2)
50-
end
11+
# transformations
12+
@test (kernel ScaleTransform(0.1))(x, y) == exp(-0.1 * totalvariation(x, y))
13+
ard = rand(10)
14+
@test (kernel ARDTransform(ard))(x, y) == exp(-totalvariation(ard .* x, ard .* y))
5115
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using Random
88
using Statistics
99
using Test
1010

11+
using CalibrationErrors: unsafe_skce_eval
12+
1113
Random.seed!(1234)
1214

1315
@testset "CalibrationErrors" begin
@@ -32,6 +34,9 @@ Random.seed!(1234)
3234
end
3335

3436
@testset "SKCE" begin
37+
@testset "generic" begin
38+
include("skce/generic.jl")
39+
end
3540
@testset "biased" begin
3641
include("skce/biased.jl")
3742
end

test/skce/biased.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
@testset "biased.jl" begin
22
@testset "Two-dimensional example" begin
3+
# categorical distributions
34
skce = BiasedSKCE(SqExponentialKernel() WhiteKernel())
4-
5-
# only two predictions, i.e., three unique terms in the estimator
65
@test iszero(@inferred(skce([[1, 0], [0, 1]], [1, 2])))
76
@test @inferred(skce([[1, 0], [0, 1]], [1, 1])) 0.5
87
@test @inferred(skce([[1, 0], [0, 1]], [2, 1])) 1 - exp(-1)
98
@test @inferred(skce([[1, 0], [0, 1]], [2, 2])) 0.5
9+
10+
# probabilities
11+
skce = BiasedSKCE((SqExponentialKernel() ScaleTransform(sqrt(2))) WhiteKernel())
12+
@test iszero(@inferred(skce([1, 0], [true, false])))
13+
@test @inferred(skce([1, 0], [true, true])) 0.5
14+
@test @inferred(skce([1, 0], [false, true])) 1 - exp(-1)
15+
@test @inferred(skce([1, 0], [false, false])) 0.5
1016
end
1117

1218
@testset "Basic properties" begin
1319
skce = BiasedSKCE((ExponentialKernel() ScaleTransform(0.1)) WhiteKernel())
1420
estimates = Vector{Float64}(undef, 1_000)
1521

22+
# categorical distributions
1623
for nclasses in (2, 10, 100)
1724
dist = Dirichlet(nclasses, 1.0)
1825

@@ -27,5 +34,17 @@
2734

2835
@test all(x -> x > zero(x), estimates)
2936
end
37+
38+
# probabilities
39+
predictions = Vector{Float64}(undef, 20)
40+
targets = Vector{Bool}(undef, 20)
41+
for i in 1:length(estimates)
42+
rand!(predictions)
43+
map!(targets, predictions) do p
44+
rand() < p
45+
end
46+
estimates[i] = skce(predictions, targets)
47+
end
48+
@test all(x -> x > zero(x), estimates)
3049
end
3150
end

test/skce/generic.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
@testset "generic.jl" begin
2+
# alternative implementation of white kernel
3+
struct WhiteKernel2 <: Kernel end
4+
(::WhiteKernel2)(x, y) = x == y
5+
6+
# alternative implementation TensorProductKernel
7+
struct TensorProduct2{K1<:Kernel,K2<:Kernel} <: Kernel
8+
kernel1::K1
9+
kernel2::K2
10+
end
11+
function (kernel::TensorProduct2)((x1, x2), (y1, y2))
12+
return kernel.kernel1(x1, y1) * kernel.kernel2(x2, y2)
13+
end
14+
15+
@testset "binary classification" begin
16+
# probabilities and boolean targets
17+
p, p̃ = rand(2)
18+
y, ỹ = rand(Bool, 2)
19+
scale = rand()
20+
kernel = SqExponentialKernel() ScaleTransform(scale)
21+
val = unsafe_skce_eval(kernel WhiteKernel(), p, y, p̃, ỹ)
22+
@test unsafe_skce_eval(kernel WhiteKernel2(), p, y, p̃, ỹ) val
23+
@test unsafe_skce_eval(TensorProduct2(kernel, WhiteKernel()), p, y, p̃, ỹ) val
24+
@test unsafe_skce_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, p̃, ỹ) val
25+
26+
# corresponding values and kernel for full categorical distribution
27+
pfull = [p, 1 - p]
28+
yint = y ? 1 : 2
29+
p̃full = [p̃, 1 - p̃]
30+
ỹint =? 1 : 2
31+
kernelfull = SqExponentialKernel() ScaleTransform(scale / sqrt(2))
32+
33+
@test unsafe_skce_eval(kernelfull WhiteKernel(), pfull, yint, p̃full, ỹint) val
34+
@test unsafe_skce_eval(kernelfull WhiteKernel2(), pfull, yint, p̃full, ỹint)
35+
val
36+
@test unsafe_skce_eval(
37+
TensorProduct2(kernelfull, WhiteKernel()), pfull, yint, p̃full, ỹint
38+
) val
39+
@test unsafe_skce_eval(
40+
TensorProduct2(kernelfull, WhiteKernel2()), pfull, yint, p̃full, ỹint
41+
) val
42+
end
43+
44+
@testset "multi-class classification" begin
45+
n = 10
46+
p = rand(n)
47+
p ./= sum(p)
48+
y = rand(1:n)
49+
= rand(n)
50+
./= sum(p̃)
51+
= rand(1:n)
52+
53+
kernel = SqExponentialKernel() ScaleTransform(rand())
54+
val = unsafe_skce_eval(kernel WhiteKernel(), p, y, p̃, ỹ)
55+
56+
@test unsafe_skce_eval(kernel WhiteKernel2(), p, y, p̃, ỹ) val
57+
@test unsafe_skce_eval(TensorProduct2(kernel, WhiteKernel()), p, y, p̃, ỹ) val
58+
@test unsafe_skce_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, p̃, ỹ) val
59+
end
60+
end

test/skce/unbiased.jl

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
@testset "unbiased.jl" begin
22
@testset "Unbiased: Two-dimensional example" begin
3+
# categorical distributions
34
skce = UnbiasedSKCE(SqExponentialKernel() WhiteKernel())
4-
5-
# only two predictions, i.e., one term in the estimator
65
@test iszero(@inferred(skce([[1, 0], [0, 1]], [1, 2])))
76
@test iszero(@inferred(skce([[1, 0], [0, 1]], [1, 1])))
87
@test @inferred(skce([[1, 0], [0, 1]], [2, 1])) -2 * exp(-1)
98
@test iszero(@inferred(skce([[1, 0], [0, 1]], [2, 2])))
9+
10+
# probabilities
11+
skce = UnbiasedSKCE(
12+
(SqExponentialKernel() ScaleTransform(sqrt(2))) WhiteKernel()
13+
)
14+
@test iszero(@inferred(skce([1, 0], [true, false])))
15+
@test iszero(@inferred(skce([1, 0], [true, true])))
16+
@test @inferred(skce([1, 0], [false, true])) -2 * exp(-1)
17+
@test iszero(@inferred(skce([1, 0], [false, false])))
1018
end
1119

1220
@testset "Unbiased: Basic properties" begin
1321
skce = UnbiasedSKCE((ExponentialKernel() ScaleTransform(0.1)) WhiteKernel())
1422
estimates = Vector{Float64}(undef, 1_000)
1523

24+
# categorical distributions
1625
for nclasses in (2, 10, 100)
1726
dist = Dirichlet(nclasses, 1.0)
1827

@@ -30,13 +39,26 @@
3039
@test any(x -> x < zero(x), estimates)
3140
@test mean(estimates) 0 atol = 1e-3
3241
end
42+
43+
# probabilities
44+
predictions = Vector{Float64}(undef, 20)
45+
targets = Vector{Bool}(undef, 20)
46+
for i in 1:length(estimates)
47+
rand!(predictions)
48+
map!(targets, predictions) do p
49+
rand() < p
50+
end
51+
estimates[i] = skce(predictions, targets)
52+
end
53+
54+
@test any(x -> x > zero(x), estimates)
55+
@test any(x -> x < zero(x), estimates)
56+
@test mean(estimates) 0 atol = 1e-3
3357
end
3458

3559
@testset "Block: Two-dimensional example" begin
36-
# Blocks of two samples
60+
# categorical distributions
3761
skce = BlockUnbiasedSKCE(SqExponentialKernel() WhiteKernel())
38-
39-
# only two predictions, i.e., one term in the estimator
4062
@test iszero(@inferred(skce([[1, 0], [0, 1]], [1, 2])))
4163
@test iszero(@inferred(skce([[1, 0], [0, 1]], [1, 1])))
4264
@test @inferred(skce([[1, 0], [0, 1]], [2, 1])) -2 * exp(-1)
@@ -48,6 +70,21 @@
4870
@test @inferred(skce(repeat([[1, 0], [0, 1]], 10), repeat([2, 1], 10)))
4971
-2 * exp(-1)
5072
@test iszero(@inferred(skce(repeat([[1, 0], [0, 1]], 10), repeat([2, 2], 10))))
73+
74+
# probabilities
75+
skce = BlockUnbiasedSKCE(
76+
(SqExponentialKernel() ScaleTransform(sqrt(2))) WhiteKernel()
77+
)
78+
@test iszero(@inferred(skce([1, 0], [true, false])))
79+
@test iszero(@inferred(skce([1, 0], [true, true])))
80+
@test @inferred(skce([1, 0], [false, true])) -2 * exp(-1)
81+
@test iszero(@inferred(skce([1, 0], [false, false])))
82+
83+
# two predictions, ten times replicated
84+
@test iszero(@inferred(skce(repeat([1, 0], 10), repeat([true, false], 10))))
85+
@test iszero(@inferred(skce(repeat([1, 0], 10), repeat([true, true], 10))))
86+
@test @inferred(skce(repeat([1, 0], 10), repeat([false, true], 10))) -2 * exp(-1)
87+
@test iszero(@inferred(skce(repeat([1, 0], 10), repeat([false, false], 10))))
5188
end
5289

5390
@testset "Block: Basic properties" begin
@@ -58,6 +95,7 @@
5895
blockskce_all = BlockUnbiasedSKCE(kernel, nsamples)
5996
estimates = Vector{Float64}(undef, 1_000)
6097

98+
# categorical distributions
6199
for nclasses in (2, 10, 100)
62100
dist = Dirichlet(nclasses, 1.0)
63101

@@ -82,5 +120,28 @@
82120
@test any(x -> x < zero(x), estimates)
83121
@test mean(estimates) 0 atol = 5e-3
84122
end
123+
124+
# probabilities
125+
predictions = Vector{Float64}(undef, nsamples)
126+
targets = Vector{Bool}(undef, nsamples)
127+
128+
for i in 1:length(estimates)
129+
rand!(predictions)
130+
map!(targets, predictions) do p
131+
return rand() < p
132+
end
133+
estimates[i] = blockskce(predictions, targets)
134+
135+
# consistency checks
136+
@test estimates[i] mean(
137+
skce(predictions[(2 * i - 1):(2 * i)], targets[(2 * i - 1):(2 * i)]) for
138+
i in 1:(nsamples ÷ 2)
139+
)
140+
@test skce(predictions, targets) == blockskce_all(predictions, targets)
141+
end
142+
143+
@test any(x -> x > zero(x), estimates)
144+
@test any(x -> x < zero(x), estimates)
145+
@test mean(estimates) 0 atol = 5e-3
85146
end
86147
end

0 commit comments

Comments
 (0)