diff --git a/Project.toml b/Project.toml
index 9d9a030..cf9d07c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,22 +4,31 @@ authors = ["Ronny Bergmann "]
 version = "0.1.0"
 
 [deps]
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
 Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
 ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
 Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
-Manifolds = "0.3"
+MLJBase = "0.15"
+MLJModelInterface = "0.3"
+Manifolds = "0.3, 0.4"
 ManifoldsBase = "0.9"
 Manopt = "0.2"
 julia = "1.0"
 
 [extras]
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Distances", "MultivariateStats", "Test"]
+test = ["DataFrames", "Distances", "MultivariateStats", "Test"]
diff --git a/src/ManifoldML.jl b/src/ManifoldML.jl
index 7ec2954..9413ef5 100644
--- a/src/ManifoldML.jl
+++ b/src/ManifoldML.jl
@@ -3,10 +3,16 @@
 using ManifoldsBase
 using Manopt
 import Manopt: initialize_solver!, step_solver!
 using Manifolds: mean
+using MLJBase
+using MLJModelInterface
+using MLJScientificTypes
 using Requires
+using Tables # needed for MLJ
+
 
 include("kmeans.jl")
+include("tangent_transformer.jl")
 
 function __init__()
     @require Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" begin
diff --git a/src/kmeans.jl b/src/kmeans.jl
index 41537de..7608d02 100644
--- a/src/kmeans.jl
+++ b/src/kmeans.jl
@@ -27,11 +27,19 @@ struct KMeansOptions{P} <: Options
     centers::Vector{P}
     assignment::Vector{<:Int}
     stop::StoppingCriterion
-    function KMeansOptions{P}(points::Vector{P}, centers::Vector{P}, stop::StoppingCriterion) where {P}
-        return new(points, centers, zeros(Int,length(points)), stop)
+    function KMeansOptions{P}(
+        points::Vector{P},
+        centers::Vector{P},
+        stop::StoppingCriterion,
+    ) where {P}
+        return new(points, centers, zeros(Int, length(points)), stop)
     end
 end
 
-function KMeansOptions(points::Vector{P}, centers::Vector{P}, stop::StoppingCriterion=StopAfterIteration(100)) where {P}
+function KMeansOptions(
+    points::Vector{P},
+    centers::Vector{P},
+    stop::StoppingCriterion = StopAfterIteration(100),
+) where {P}
     return KMeansOptions{P}(points, centers, stop)
 end
 
@@ -40,26 +48,26 @@ end
 Store the fixed data necessary for [`kmeans`](@ref), i.e. only a `Manifold M`.
""" -struct KMeansProblem{TM <: Manifold} <: Problem +struct KMeansProblem{TM<:Manifold} <: Problem M::TM end function initialize_solver!(p::KMeansProblem, o::KMeansOptions) - k_means_update_assignment!(p,o) + return k_means_update_assignment!(p, o) end function step_solver!(p::KMeansProblem, o::KMeansOptions, ::Int) # (1) Update assignments - k_means_update_assignment!(p,o) + k_means_update_assignment!(p, o) # (2) Update centers - for i=1:length(o.centers) - any(o.assignment==i) && mean!(p.M, o.centers[i], o.points[o.assignment==i]) + for i in 1:length(o.centers) + any(o.assignment == i) && mean!(p.M, o.centers[i], o.points[o.assignment == i]) end end function k_means_update_assignment!(p::KMeansProblem, o::KMeansOptions) - for i=1:length(o.points) - o.assignment[i] = argmin([ distance(p.M,o.points[i],c) for c in o.centers ] ) + for i in 1:length(o.points) + o.assignment[i] = argmin([distance(p.M, o.points[i], c) for c in o.centers]) end end @@ -80,15 +88,17 @@ decorators from [Manopt.jl](https://manoptjl.org) Returns the final [`KMeansOptions`](@ref) including the final assignment vector and the centers. """ -function kmeans(M::Manifold, pts::Vector{P}; +function kmeans( + M::Manifold, + pts::Vector{P}; num_centers = 5, centers = pts[1:num_centers], - stop=StopAfterIteration(100), - kwargs... - ) where {P} + stop = StopAfterIteration(100), + kwargs..., +) where {P} p = KMeansProblem(M) - o = KMeansOptions(pts,centers,stop) + o = KMeansOptions(pts, centers, stop) o = decorate_options(o; kwargs...) - oR = solve(p,o) + oR = solve(p, o) return get_options(oR) -end \ No newline at end of file +end diff --git a/src/tangent_transformer.jl b/src/tangent_transformer.jl new file mode 100644 index 0000000..c93047f --- /dev/null +++ b/src/tangent_transformer.jl @@ -0,0 +1,300 @@ + +const MLJManifoldPoint = Tuple{Any,Manifold} + + +mutable struct UnivariateTangentSpaceTransformer <: Unsupervised + p::Union{Symbol,MLJManifoldPoint} + retraction::AbstractRetractionMethod + inverse_retraction::AbstractInverseRetractionMethod +end + +function UnivariateTangentSpaceTransformer() + return UnivariateTangentSpaceTransformer( + :mean, + ExponentialRetraction(), + LogarithmicInverseRetraction(), + ) +end + +""" + TangentSpaceTransformer(p = :mean) + +Unsupervised model for transforming data to tangent space. +`p` is either equal to `:mean`, in which case data will be transformed to the tangent +space at the mean of points given to `fit`, or a specific point on a manifold (represented as a tuple). 
+""" +mutable struct TangentSpaceTransformer <: Unsupervised + p::Union{Symbol,MLJManifoldPoint} + retraction::AbstractRetractionMethod + inverse_retraction::AbstractInverseRetractionMethod + basis::ManifoldsBase.AbstractBasis + features::AbstractVector{Symbol} # if not empty only these features will be transformed +end + +function TangentSpaceTransformer() + return TangentSpaceTransformer( + :mean, + ExponentialRetraction(), + LogarithmicInverseRetraction(), + DefaultOrthonormalBasis(), + Symbol[], + ) +end + +function univariate_to_tspace_fit( + M::Manifold, + p, + v; + retraction::AbstractRetractionMethod, + inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + basis = DefaultOrthonormalBasis(), +) + if p === :mean + point = ( + mean( + M, + map(q -> q[1], v); + retraction = retraction, + inverse_retraction = inverse_retraction, + ), + M, + ) + else + point = p + end + fitresult = (point..., basis) + cache = nothing + report = NamedTuple() + return fitresult, cache, report +end + +function univariate_to_tspace_transform( + M::Manifold, + p, + basis, + data, + inverse_retraction::AbstractInverseRetractionMethod, +) + inv_retrs = map( + q -> get_coordinates(M, p, inverse_retract(M, p, q[1], inverse_retraction), basis), + data, + ) + return map(i -> map(u -> inv_retrs[u][i], 1:length(inv_retrs)), 1:manifold_dimension(M)) +end + +function univariate_to_tspace_inverse_transform( + M::Manifold, + p, + basis, + data, + retraction::AbstractRetractionMethod, +) + arr_data = Array(data) + retrs = map( + i -> retract(M, p, get_vector(M, p, arr_data[i, :], basis), retraction), + axes(arr_data, 1), + ) + return map(u -> (u, M), retrs) +end + +function MLJBase.fit(transformer::TangentSpaceTransformer, verbosity::Int, X) + is_univariate = !Tables.istable(X) + + if is_univariate + M = X[1][2] + return ( + ( + is_univariate = true, + fitresult = univariate_to_tspace_fit( + M, + transformer.p, + X; + basis = transformer.basis, + retraction = transformer.retraction, + inverse_retraction = transformer.inverse_retraction, + )[1], + ), + nothing, + nothing, + ) + end + + all_features = Tables.schema(X).names + feature_scitypes = collect(elscitype(selectcols(X, c)) for c in all_features) + if isempty(transformer.features) + cols_to_fit = filter!(collect(eachindex(all_features))) do j + return feature_scitypes[j] <: MLJScientificTypes.ManifoldPoint + end + else + issubset(transformer.features, all_features) || + @warn "Some specified features not present in table to be fit. 
" + cols_to_fit = filter!(collect(eachindex(all_features))) do j + return (all_features[j] in transformer.features) && + feature_scitypes[j] <: MLJScientificTypes.ManifoldPoint + end + end + fitresult_given_feature = Dict{Symbol,Any}() + + isempty(cols_to_fit) && verbosity > -1 && @warn "No features to transform" + for j in cols_to_fit + col_data = selectcols(X, j) + col_fitresult, cache, report = univariate_to_tspace_fit( + col_data[1][2], + transformer.p, + col_data; + retraction = transformer.retraction, + inverse_retraction = transformer.inverse_retraction, + basis = transformer.basis, + ) + fitresult_given_feature[all_features[j]] = col_fitresult + end + + fitresult = (is_univariate = false, fitresult_given_feature = fitresult_given_feature) + + return fitresult, nothing, nothing +end + + +function MLJBase.fitted_params(::TangentSpaceTransformer, fitresult) + if fitresult.is_univariate + return (point = fitresult.fitresult[1:2], basis = fitresult.fitresult[3]) + else + error("TODO") + end +end + +# for transforming single value: +function MLJBase.transform( + transformer::UnivariateTangentSpaceTransformer, + fitresult, + p::MLJManifoldPoint, +) + q, M, basis = fitresult + X = inverse_retract(M, q, p[1], transformer.inverse_retraction) + coeffs = get_coefficients(M, q, X, basis) + new_features = [Symbol("X_$i") for i in 1:manifold_dimension(M)] + named_cols = NamedTuple{tuple(new_features...)}(tuple(coeffs)...) + return MLJBase.table(named_cols) +end + +# for transforming vector: +function MLJBase.transform(transformer::TangentSpaceTransformer, fitresult, ps) + is_univariate = fitresult.is_univariate + + if is_univariate + M = ps[1][2] + return univariate_to_tspace_transform( + M, + fitresult.fitresult[1], + transformer.basis, + ps, + transformer.inverse_retraction, + ) + end + + features_to_be_transformed = keys(fitresult.fitresult_given_feature) + + all_features = Tables.schema(ps).names + + all(e -> e in all_features, features_to_be_transformed) || + error("Attempting to transform data with incompatible feature labels.") + + new_features = Symbol[] + new_cols = [] + + + for ftr in all_features + ftr_data = selectcols(ps, ftr) + if ftr in features_to_be_transformed + fgf = fitresult.fitresult_given_feature[ftr] + M = fgf[2] + feature_names = [Symbol("$(ftr)_$i") for i in 1:manifold_dimension(M)] + append!(new_features, feature_names) + cols = univariate_to_tspace_transform( + M, + fgf[1], + fgf[3], + ftr_data, + transformer.inverse_retraction, + ) + append!(new_cols, cols) + else + push!(new_features, ftr) + push!(new_cols, ftr_data) + end + end + + named_cols = NamedTuple{tuple(new_features...)}(tuple(new_cols...)) + + return MLJBase.table(named_cols, prototype = ps) +end + +function MLJBase.inverse_transform(transformer::TangentSpaceTransformer, fitresult, X) + is_univariate = fitresult.is_univariate + + if is_univariate + return inverse_transform(transformer, fitresult.fitresult, X) + end + + features_transformed = keys(fitresult.fitresult_given_feature) + features_processed = Symbol[] + + all_features = Tables.schema(X).names + + new_features = Symbol[] + new_cols = [] + + for ftr in all_features + sftc = string(ftr) + last_underscore = findlast('_', sftc) + prefix = last_underscore === nothing ? 
+            ftr : Symbol(sftc[1:(last_underscore - 1)])
+        if prefix in features_transformed
+            if prefix in features_processed
+                continue
+            end
+            fgf = fitresult.fitresult_given_feature[prefix]
+            M = fgf[2]
+            feature_names = [Symbol("$(prefix)_$i") for i in 1:manifold_dimension(M)]
+            push!(new_features, prefix)
+            ftr_data = selectcols(X, feature_names)
+            new_col = univariate_to_tspace_inverse_transform(
+                M,
+                fgf[1],
+                fgf[3],
+                ftr_data,
+                transformer.retraction,
+            )
+            push!(new_cols, new_col)
+            push!(features_processed, prefix)
+        else
+            ftr_data = selectcols(X, ftr)
+            push!(new_features, ftr)
+            push!(new_cols, ftr_data)
+        end
+    end
+    named_cols = NamedTuple{tuple(new_features...)}(tuple(new_cols...))
+
+    return MLJBase.table(named_cols, prototype = X)
+end
+
+
+# for single values:
+function MLJBase.inverse_transform(
+    transformer::UnivariateTangentSpaceTransformer,
+    fitresult,
+    coeffs,
+)
+    q, M, basis = fitresult
+    X = get_vector(M, q, coeffs, basis)
+    p = retract(M, q, X, transformer.retraction)
+    return (p, M)
+end
+
+# for vectors:
+function MLJBase.inverse_transform(
+    transformer::UnivariateTangentSpaceTransformer,
+    fitresult,
+    w::AbstractVector,
+)
+    return [MLJBase.inverse_transform(transformer, fitresult, y) for y in w]
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2e5e51c..487e00d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,3 +2,4 @@ include("utils.jl")
 
 include("test_kmeans.jl")
 include("test_distances.jl")
+include("test_mlj.jl")
diff --git a/test/test_distances.jl b/test/test_distances.jl
index 3f4710e..deb06d9 100644
--- a/test/test_distances.jl
+++ b/test/test_distances.jl
@@ -12,7 +12,8 @@ Random.seed!(42)
         A(π / 6) * [1.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 1] * transpose(A(π / 6)),
     ]
     dist = ManifoldML.RiemannianDistance(M)
-    @test evaluate(dist, reshape(ptsF[1], 9), reshape(ptsF[3], 9)) ≈ distance(M, ptsF[1], ptsF[3])
+    @test Distances.evaluate(dist, reshape(ptsF[1], 9), reshape(ptsF[3], 9)) ≈
+          distance(M, ptsF[1], ptsF[3])
 
     point_matrix = reduce(hcat, map(a -> reshape(a, 9), ptsF))
     dists = pairwise(dist, point_matrix)
diff --git a/test/test_kmeans.jl b/test/test_kmeans.jl
index 67b87e6..e1590ce 100644
--- a/test/test_kmeans.jl
+++ b/test/test_kmeans.jl
@@ -4,15 +4,15 @@ Random.seed!(42)
 
 @testset "K-Means" begin
     M = Sphere(2)
-    p1 = 1/sqrt(3) .* [1.0, 1.0, 1.0]
-    p2 = 1/sqrt(3) .* [-1.0, -1.0, 1.0]
-    c1 = [exp(M,p1, project(M, p1, 0.7 .* (rand(3) .- 0.5 ))) for i=1:20]
-    c2 = [exp(M,p2, project(M, p2, 0.7 .* (rand(3) .- 0.5 ))) for i=1:20]
+    p1 = 1 / sqrt(3) .* [1.0, 1.0, 1.0]
+    p2 = 1 / sqrt(3) .* [-1.0, -1.0, 1.0]
+    c1 = [exp(M, p1, project(M, p1, 0.7 .* (rand(3) .- 0.5))) for i in 1:20]
+    c2 = [exp(M, p2, project(M, p2, 0.7 .* (rand(3) .- 0.5))) for i in 1:20]
     pts = [p1, p2, c1..., c2...]
 
-    o = kmeans(M, pts; num_centers=2)
-    @test distance(M,p1,o.centers[1]) ≈ 0
-    @test distance(M,p2,o.centers[2]) ≈ 0
+    o = kmeans(M, pts; num_centers = 2)
+    @test distance(M, p1, o.centers[1]) ≈ 0
+    @test distance(M, p2, o.centers[2]) ≈ 0
     @test sum(o.assignment .== 1) == 21
     @test sum(o.assignment .== 2) == 21
-end
\ No newline at end of file
+end
diff --git a/test/test_mlj.jl b/test/test_mlj.jl
new file mode 100644
index 0000000..2f93368
--- /dev/null
+++ b/test/test_mlj.jl
@@ -0,0 +1,59 @@
+include("utils.jl")
+using MLJBase
+using MLJModelInterface
+using DataFrames
+
+@testset "MLJ interoperability" begin
+    M = Sphere(2)
+    p1 = [1.0, 0.0, 0.0]
+    p2 = [0.0, 1.0, 0.0]
+    p3 = [0.0, sqrt(2) / 2, -sqrt(2) / 2]
+
+    X = DataFrame(pm = [(p1, M), (p2, M), (p3, M)], y = [1, 2, 1])
+
+    tst_model = ManifoldML.TangentSpaceTransformer(
+        :mean,
+        ExponentialRetraction(),
+        LogarithmicInverseRetraction(),
+        DefaultOrthogonalBasis(),
+        [:pm],
+    )
+    pm = mean(M, [p1, p2, p3])
+
+    logs = map(
+        y -> get_coordinates(M, pm, log(M, pm, y), DefaultOrthogonalBasis()),
+        [p1, p2, p3],
+    )
+
+    @testset "univariate" begin
+        fitted_univariate = fit!(machine(tst_model, X[:pm]))
+        transformed_univariate = MLJBase.transform(fitted_univariate, X[:pm])
+        for i in 1:2
+            @test transformed_univariate[i] == map(y -> y[i], logs)
+        end
+    end
+
+    @testset "zero argument constructor" begin
+        default_tst_model = ManifoldML.TangentSpaceTransformer()
+        @test default_tst_model.p === :mean
+    end
+
+    fitted_model = fit!(machine(tst_model, X))
+    transformed = MLJBase.transform(fitted_model, X)
+
+    @testset "forward transform" begin
+        @test schema(transformed).names == (:pm_1, :pm_2, :y)
+        @test schema(transformed).types == (Float64, Float64, Int64)
+        @test schema(transformed).scitypes == (Continuous, Continuous, Count)
+
+        for i in 1:2
+            @test transformed[Symbol("pm_$i")] == map(y -> y[i], logs)
+        end
+    end
+
+    inv_transformed = MLJBase.inverse_transform(fitted_model, transformed)
+    @testset "inverse transform" begin
+        @test schema(inv_transformed) == schema(X)
+        @test isapprox(map(p -> p[1], X[:pm]), map(p -> p[1], inv_transformed[:pm]))
+    end
+end
diff --git a/test/utils.jl b/test/utils.jl
index c16c90b..370fefc 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,2 +1,2 @@
 using ManifoldML, Manifolds, ManifoldsBase, Test
-using Random
\ No newline at end of file
+using Random
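
Usage sketch (not part of the patch): a minimal example of how the TangentSpaceTransformer introduced here is meant to be driven through an MLJBase machine, mirroring the new test/test_mlj.jl. It assumes, as the tests do, that a DataFrame column of (point, manifold) tuples is recognised by the MLJScientificTypes ManifoldPoint glue and that DefaultOrthogonalBasis is usable on the sphere.

    using ManifoldML, Manifolds, ManifoldsBase, MLJBase, DataFrames

    M = Sphere(2)
    # one manifold-valued column of (point, manifold) tuples plus an ordinary column
    X = DataFrame(
        pm = [([1.0, 0.0, 0.0], M), ([0.0, 1.0, 0.0], M), ([0.0, 0.0, 1.0], M)],
        y = [1, 2, 1],
    )

    # transform :pm to tangent-space coordinates at the Riemannian mean of the column
    t = ManifoldML.TangentSpaceTransformer(
        :mean,
        ExponentialRetraction(),
        LogarithmicInverseRetraction(),
        DefaultOrthogonalBasis(),
        [:pm],
    )
    mach = fit!(machine(t, X))
    Xt = MLJBase.transform(mach, X)           # columns :pm_1, :pm_2 (tangent coordinates) and :y
    Xb = MLJBase.inverse_transform(mach, Xt)  # :pm holds (point, M) tuples on the sphere again

The kmeans routine in src/kmeans.jl is independent of this MLJ glue and can still be called directly, e.g. kmeans(M, pts; num_centers = 2), as exercised in test/test_kmeans.jl.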