Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,31 @@ authors = ["Ronny Bergmann <[email protected]>"]
version = "0.1.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Manifolds = "0.3"
MLJBase = "0.15"
MLJModelInterface = "0.3"
Manifolds = "0.3, 0.4"
ManifoldsBase = "0.9"
Manopt = "0.2"
julia = "1.0"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distances", "MultivariateStats", "Test"]
test = ["DataFrames", "Distances", "MultivariateStats", "Test"]
6 changes: 6 additions & 0 deletions src/ManifoldML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ using ManifoldsBase
using Manopt
import Manopt: initialize_solver!, step_solver!
using Manifolds: mean
using MLJBase
using MLJModelInterface
using MLJScientificTypes

using Requires

using Tables # needed for MLJ

include("kmeans.jl")
include("tangent_transformer.jl")

function __init__()
@require Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" begin
Expand Down
44 changes: 27 additions & 17 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,19 @@ struct KMeansOptions{P} <: Options
centers::Vector{P}
assignment::Vector{<:Int}
stop::StoppingCriterion
function KMeansOptions{P}(points::Vector{P}, centers::Vector{P}, stop::StoppingCriterion) where {P}
return new(points, centers, zeros(Int,length(points)), stop)
function KMeansOptions{P}(
points::Vector{P},
centers::Vector{P},
stop::StoppingCriterion,
) where {P}
return new(points, centers, zeros(Int, length(points)), stop)
end
end
function KMeansOptions(points::Vector{P}, centers::Vector{P}, stop::StoppingCriterion=StopAfterIteration(100)) where {P}
function KMeansOptions(
points::Vector{P},
centers::Vector{P},
stop::StoppingCriterion = StopAfterIteration(100),
) where {P}
return KMeansOptions{P}(points, centers, stop)
end

Expand All @@ -40,26 +48,26 @@ end

Store the fixed data necessary for [`kmeans`](@ref), i.e. only a `Manifold M`.
"""
struct KMeansProblem{TM <: Manifold} <: Problem
struct KMeansProblem{TM<:Manifold} <: Problem
M::TM
end

function initialize_solver!(p::KMeansProblem, o::KMeansOptions)
k_means_update_assignment!(p,o)
return k_means_update_assignment!(p, o)
end

function step_solver!(p::KMeansProblem, o::KMeansOptions, ::Int)
# (1) Update assignments
k_means_update_assignment!(p,o)
k_means_update_assignment!(p, o)
# (2) Update centers
for i=1:length(o.centers)
any(o.assignment==i) && mean!(p.M, o.centers[i], o.points[o.assignment==i])
for i in 1:length(o.centers)
any(o.assignment == i) && mean!(p.M, o.centers[i], o.points[o.assignment == i])
end
end

function k_means_update_assignment!(p::KMeansProblem, o::KMeansOptions)
for i=1:length(o.points)
o.assignment[i] = argmin([ distance(p.M,o.points[i],c) for c in o.centers ] )
for i in 1:length(o.points)
o.assignment[i] = argmin([distance(p.M, o.points[i], c) for c in o.centers])
end
end

Expand All @@ -80,15 +88,17 @@ decorators from [Manopt.jl](https://manoptjl.org)

Returns the final [`KMeansOptions`](@ref) including the final assignment vector and the centers.
"""
function kmeans(M::Manifold, pts::Vector{P};
function kmeans(
M::Manifold,
pts::Vector{P};
num_centers = 5,
centers = pts[1:num_centers],
stop=StopAfterIteration(100),
kwargs...
) where {P}
stop = StopAfterIteration(100),
kwargs...,
) where {P}
p = KMeansProblem(M)
o = KMeansOptions(pts,centers,stop)
o = KMeansOptions(pts, centers, stop)
o = decorate_options(o; kwargs...)
oR = solve(p,o)
oR = solve(p, o)
return get_options(oR)
end
end
Loading