diff --git a/Cargo.toml b/Cargo.toml index 184656e91..237b311e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,11 @@ optional = true default-features = false features = ["cblas"] +[dependencies.linfa-kernel] +version = "0.4.0" +path = "algorithms/linfa-kernel" +optional = true + [dev-dependencies] ndarray-rand = "0.13" linfa-datasets = { path = "datasets", features = ["winequality", "iris", "diabetes"] } diff --git a/algorithms/linfa-hierarchical/src/lib.rs b/algorithms/linfa-hierarchical/src/lib.rs index 7f4e53641..3ad82adf9 100644 --- a/algorithms/linfa-hierarchical/src/lib.rs +++ b/algorithms/linfa-hierarchical/src/lib.rs @@ -19,12 +19,12 @@ //! [kodama](https://docs.rs/kodama/0.2.3/kodama/) crate. use std::collections::HashMap; +use ndarray::Array1; use kodama::linkage; pub use kodama::Method; -use linfa::dataset::DatasetBase; -use linfa::traits::Transformer; +use linfa::traits::PredictRef; use linfa::Float; use linfa_kernel::Kernel; @@ -40,9 +40,50 @@ enum Criterion { /// Agglomerative hierarchical clustering /// -/// In this clustering algorithm, each point is first considered as a separate cluster. During each -/// step, two points are merged into new clusters, until a stopping criterion is reached. The distance -/// between the points is computed as the negative-log transform of the similarity kernel. +/// Hierarchical clustering is a method of cluster analysis which seeks to build a hierarchy of +/// clusters. First, each point is considered as a separate cluster. During each step, two points +/// are merged into new clusters, until a stopping criterion is reached. The distance between the +/// points is computed as the negative-log transform of the similarity kernel. +/// +/// # Example +/// +/// This example loads the iris flower dataset and performs hierarchical clustering into three +/// separate clusters.
+/// ```rust +/// use std::error::Error; +/// +/// use linfa::traits::Transformer; +/// use linfa_hierarchical::HierarchicalCluster; +/// use linfa_kernel::{Kernel, KernelMethod}; +/// +/// fn main() -> Result<(), Box<dyn Error>> { +/// // load Iris plant dataset +/// let dataset = linfa_datasets::iris(); +/// +/// let kernel = Kernel::params() +/// .method(KernelMethod::Gaussian(1.0)) +/// .transform(dataset.records().view()); +/// +/// let kernel = HierarchicalCluster::default() +/// .num_clusters(3) +/// .transform(kernel); +/// +/// for (id, target) in kernel.targets().iter().zip(dataset.targets().into_iter()) { +/// let name = match *target as usize { +/// 0 => "setosa", +/// 1 => "versicolor", +/// 2 => "virginica", +/// _ => unreachable!(), +/// }; +/// +/// print!("({} {}) ", id, name); +/// } +/// println!(); +/// +/// Ok(()) +/// } +/// ``` + pub struct HierarchicalCluster { method: Method, stopping: Criterion, @@ -77,13 +118,14 @@ impl HierarchicalCluster { } } -impl Transformer, DatasetBase, Vec>> +/// Predict cluster assignments with a kernel operator +impl PredictRef, Array1> for HierarchicalCluster { /// Perform hierarchical clustering of a similarity matrix /// /// Returns the class id for each data point - fn transform(&self, kernel: Kernel) -> DatasetBase, Vec> { + fn predict_ref(&self, kernel: &Kernel) -> Array1 { // ignore all similarities below this value let threshold = F::cast(1e-6); @@ -145,19 +187,7 @@ impl Transformer, DatasetBase, Vec>> } // return node_index -> cluster_index map - DatasetBase::new(kernel, tmp) - } -} - -impl Transformer, T>, DatasetBase, Vec>> - for HierarchicalCluster -{ - /// Perform hierarchical clustering of a similarity matrix - /// - /// Returns the class id for each data point - fn transform(&self, dataset: DatasetBase, T>) -> DatasetBase, Vec> { - //let Dataset { records, .. 
} = dataset; - self.transform(dataset.records) + Array1::from(tmp) } } @@ -174,7 +204,8 @@ impl Default for HierarchicalCluster { #[cfg(test)] mod tests { - use linfa::traits::Transformer; + use linfa::traits::{Transformer, Predict}; + use linfa::Dataset; use linfa_kernel::{Kernel, KernelMethod}; use ndarray::{Array, Axis}; use ndarray_rand::{rand_distr::Normal, RandomExt}; @@ -199,12 +230,11 @@ mod tests { .method(KernelMethod::Gaussian(5.0)) .transform(entries.view()); - let kernel = HierarchicalCluster::default() + let ids = HierarchicalCluster::default() .max_distance(0.1) - .transform(kernel); + .predict_ref(&kernel); // check that all assigned ids are equal for the first cluster - let ids = kernel.targets(); let first_cluster_id = &ids[0]; assert!(ids .iter() diff --git a/algorithms/linfa-kernel/src/lib.rs b/algorithms/linfa-kernel/src/lib.rs index cac55ab11..0904b2b6f 100644 --- a/algorithms/linfa-kernel/src/lib.rs +++ b/algorithms/linfa-kernel/src/lib.rs @@ -25,11 +25,14 @@ use serde_crate::{Deserialize, Serialize}; use sprs::{CsMat, CsMatView}; use std::ops::Mul; +pub use linfa::Float; + use linfa::{ dataset::AsTargets, dataset::DatasetBase, dataset::FromTargetArray, dataset::Records, - traits::Transformer, Float, + traits::Transformer, }; + /// Kernel representation, can be either dense or sparse #[derive(Clone)] pub enum KernelType { @@ -234,18 +237,6 @@ impl<'a, F: Float> KernelView<'a, F> { } } -impl, K2: Inner> Records for KernelBase { - type Elem = F; - - fn nsamples(&self) -> usize { - self.size() - } - - fn nfeatures(&self) -> usize { - self.size() - } -} - /// The inner product definition used by a kernel. 
/// /// There are three methods available: @@ -549,6 +540,18 @@ fn sparse_from_fn>( data } +impl, K2: Inner> Records for KernelBase { + type Elem = F; + + fn nsamples(&self) -> usize { + self.size() + } + + fn nfeatures(&self) -> usize { + self.size() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index 4d9556ba1..3b3f8ea4f 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -1091,6 +1091,33 @@ where } } +#[cfg(feature = "linfa-kernel")] +mod predict_kernels_impl { + use linfa_kernel::{Kernel, KernelBase, Inner}; + use crate::traits::{Predict, PredictRef}; + use linfa_kernel::{DatasetBase, dataset::Records}; + use linfa_kernel::Float; + + impl Predict, DatasetBase, T>> for O + where + O: PredictRef, T>, + { + fn predict(&self, records: Kernel) -> DatasetBase, T> { + let new_targets = self.predict_ref(&records); + DatasetBase::new(records, new_targets) + } + } + + impl<'a, F: Float, T, O> Predict<&'a Kernel, T> for O + where + O: PredictRef, T>, + { + fn predict(&self, records: &'a Kernel) -> T { + self.predict_ref(records) + } + } +} + impl> CountedTargets { pub fn new(targets: S) -> Self { let labels = targets.label_count();