#####
##### Wasserstein distance
#####
4+
"""
    Side

Abstract dispatch tag telling `pysearchsorted` which insertion point to report
when a searched value ties with elements of the sorted collection.
"""
abstract type Side end

"Dispatch tag: report the insertion point located before any equal elements."
struct Left <: Side end

"Dispatch tag: report the insertion point located after any equal elements."
struct Right <: Side end

"""
    pysearchsorted(a, b, ::Left)
    pysearchsorted(a, b, ::Right)

Vectorized search of every element of `b` in `a`, returning zero-based
insertion indices. `a` is expected to be sorted, since the lookup relies on
`searchsortedfirst` / `searchsortedlast`.

- With `Left()` each result is `searchsortedfirst(a, x) - 1`, i.e. the number
  of elements of `a` strictly smaller than `x`.
- With `Right()` each result is `searchsortedlast(a, x)`, i.e. the number of
  elements of `a` less than or equal to `x`.

`b` is broadcast over, so a scalar `b` yields a scalar and a vector yields a
vector of the same length.

Based on the accepted answer in:
https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
"""
function pysearchsorted(a, b, ::Left)
    return broadcast(x -> searchsortedfirst(a, x) - 1, b)
end

function pysearchsorted(a, b, ::Right)
    return broadcast(x -> searchsortedlast(a, x), b)
end
17+
# Evaluate (sum(|u_cdf - v_cdf|^p .* deltas))^(1/p), the discretised integral
# of the p-th power of the gap between two step CDFs.
#
# `u_cdf`, `v_cdf` and `deltas` are combined element-wise, so they must have
# matching lengths. `p == 1` and `p == 2` take shortcuts that skip the generic
# power and root.
function compute_integral(u_cdf, v_cdf, deltas, p)
    gap = u_cdf - v_cdf

    # p = 1: plain weighted sum of absolute differences.
    p == 1 && return sum(abs.(gap) .* deltas)

    # p = 2: squaring makes the absolute value unnecessary.
    p == 2 && return sqrt(sum(gap .^ 2 .* deltas))

    # General p.
    return sum(abs.(gap) .^ p .* deltas)^(1 / p)
end
27+
# Empirical CDF of one sample, evaluated at every element of `points`.
#
# `vals` holds the observed values and `weights` their weights; `nothing`
# means every value carries the same weight.
function _empirical_cdf(vals, weights, points)
    sorter = sortperm(vals)

    # For each point: how many sample values are less than or equal to it.
    counts = pysearchsorted(vals[sorter], points, Right())

    weights === nothing && return counts / length(vals)

    # Prepending 0 lets `counts .+ 1` pick the cumulative weight of the first
    # `counts` sorted values (a count of 0 maps to a cumulative weight of 0).
    cumweights = vcat([0], cumsum(weights[sorter]))
    return cumweights[counts .+ 1] / cumweights[end]
end

"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Distance between two one-dimensional empirical distributions, computed from
their cumulative distribution functions `U` and `V` as `(∫ |U - V|^p)^(1/p)`.

# Arguments
- `p`: order of the distance (`1` gives the first Wasserstein distance).
- `u_values`, `v_values`: values observed in each distribution.
- `u_weights`, `v_weights`: optional weight for each value; `nothing` gives
  every value the same weight. Weights are normalised by their sum.

# Throws
Whatever `_validate_distribution` throws for either distribution: an
`ArgumentError` for empty values or invalid weights, a `DimensionMismatch`
when values and weights differ in length.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution(u_values, u_weights)
    _validate_distribution(v_values, v_weights)

    # Both CDFs are step functions that can only change at an observed value,
    # so the integral reduces to a sum over the gaps between successive values
    # of the pooled, sorted sample.
    all_values = sort!(vcat(u_values, v_values))
    deltas = diff(all_values)

    # Left end of every gap; the last pooled value starts no gap.
    points = @view all_values[1:end-1]

    u_cdf = _empirical_cdf(u_values, u_weights, points)
    v_cdf = _empirical_cdf(v_values, v_weights, points)

    # Compute the value of the integral based on the CDFs.
    return compute_integral(u_cdf, v_cdf, deltas, p)
end
64+
"""
    _validate_distribution(vals, weights)

Check that `vals` and, when given, `weights` describe a usable empirical
distribution. Returns `nothing` when every check passes.

# Arguments
- `vals`: observed values; must contain at least one element.
- `weights`: `nothing` for equal weighting, or one weight per value.

# Throws
- `ArgumentError` if `vals` is empty, if any weight is negative, or if the sum
  of the weights is not strictly positive and finite.
- `DimensionMismatch` if `weights` and `vals` differ in length.
"""
function _validate_distribution(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))

    # `nothing` means equal weights: there is nothing further to check.
    weights === nothing && return nothing

    # Validate the weight array.
    if length(weights) != length(vals)
        throw(DimensionMismatch("Value and weight array-likes for the same empirical distribution must be of the same size."))
    end
    any(w -> w < 0, weights) && throw(ArgumentError("All weights must be non-negative."))

    # The chained comparison is false for zero, infinite and NaN totals alike.
    if !(0 < sum(weights) < Inf)
        throw(ArgumentError("Weight array-like sum must be positive and finite. Set as `nothing` for an equal distribution of weight."))
    end
    return nothing
end
80+
"""
    wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute the first Wasserstein distance between two one-dimensional
distributions `u` and `v`.

This distance is also known as the earth mover's distance: it can be seen as
the minimum amount of "work" needed to transform `u` into `v`, where "work" is
the amount of distribution weight that must be moved multiplied by the
distance it has to be moved.

# Arguments
- `u_values`: values observed in the first (empirical) distribution.
- `v_values`: values observed in the second (empirical) distribution.
- `u_weights`: weight for each value of `u_values`, or `nothing` for equal
  weights.
- `v_weights`: weight for each value of `v_values`, or `nothing` for equal
  weights.

If a weight sum differs from 1 it must still be positive and finite, so that
the weights can be normalized to sum to 1.

# Example
    wasserstein_distance([0, 1, 3], [5, 6, 8])  # ≈ 5.0
"""
wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)
0 commit comments