📝 Description
I would like to contribute a new module to the linfa-trees crate that implements the Random Forest algorithm for classification tasks. This will expand linfa-trees from single decision trees into ensemble learning, aligning closely with scikit-learn's functionality in Python.
🚀 Motivation
Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:
- Robustness to overfitting
- Better generalization than single trees
- Feature importance estimates
Currently, linfa-trees provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.
📐 Proposed Design
🔹 New Module
A new file will be added:
`linfa-trees/src/decision_trees/random_forest.rs`
This will include:
- `RandomForestClassifier<F: Float>`
- `RandomForestParams<F>` (unchecked)
- `RandomForestValidParams<F>` (checked)
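The unchecked/checked split can be sketched as follows. This is a simplified, dependency-free sketch: the local `ParamGuard` trait is a stand-in mirroring linfa's real trait (which also provides `check_ref`), and the concrete fields and error type are illustrative assumptions, not the final API.

```rust
/// Stand-in for linfa's `ParamGuard` trait, so this sketch compiles
/// without the linfa crate.
trait ParamGuard {
    type Checked;
    type Error;
    fn check(self) -> Result<Self::Checked, Self::Error>;
}

#[derive(Debug)]
struct RandomForestError(String);

/// Unchecked hyperparameters with builder-style setters (illustrative fields).
struct RandomForestParams {
    n_trees: usize,
    feature_subsample: f64,
}

/// Validated hyperparameters; only obtainable through `check`.
struct RandomForestValidParams {
    n_trees: usize,
    feature_subsample: f64,
}

impl RandomForestParams {
    fn new() -> Self {
        Self { n_trees: 100, feature_subsample: 1.0 }
    }
    fn n_trees(mut self, n: usize) -> Self {
        self.n_trees = n;
        self
    }
    fn feature_subsample(mut self, f: f64) -> Self {
        self.feature_subsample = f;
        self
    }
}

impl ParamGuard for RandomForestParams {
    type Checked = RandomForestValidParams;
    type Error = RandomForestError;

    /// Reject impossible settings before any training happens.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        if self.n_trees == 0 {
            return Err(RandomForestError("n_trees must be > 0".into()));
        }
        if self.feature_subsample <= 0.0 || self.feature_subsample > 1.0 {
            return Err(RandomForestError("feature_subsample must be in (0, 1]".into()));
        }
        Ok(RandomForestValidParams {
            n_trees: self.n_trees,
            feature_subsample: self.feature_subsample,
        })
    }
}

fn main() {
    let valid = RandomForestParams::new().n_trees(50).check().unwrap();
    assert_eq!(valid.n_trees, 50);
    assert!(RandomForestParams::new().n_trees(0).check().is_err());
    println!("parameter validation ok");
}
```

In linfa proper, `fit` is only available on the checked type, so invalid settings fail fast instead of producing a half-trained model.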
🔹 Trait Implementations
I will implement the following traits according to linfa conventions:
- `ParamGuard` for parameter validation
- `Fit` to train the forest using bootstrapped data and random feature subsetting
- `PredictInplace` and `Predict` to perform inference via majority voting
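The two ensemble ingredients these trait impls rely on, bootstrap sampling in `Fit` and majority voting in `Predict`, can be sketched in dependency-free Rust. The tiny `Lcg` generator and helper names here are hypothetical placeholders; the real implementation would use a proper RNG crate and operate on `ndarray` views as linfa does.

```rust
use std::collections::HashMap;

/// Minimal deterministic linear congruential generator, used only so this
/// sketch needs no external crates (constants from Knuth's MMIX LCG).
struct Lcg(u64);

impl Lcg {
    fn next(&mut self) -> u64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }
    fn below(&mut self, bound: usize) -> usize {
        (self.next() % bound as u64) as usize
    }
}

/// Draw `n` row indices with replacement: one bootstrap sample per tree.
fn bootstrap_indices(n: usize, rng: &mut Lcg) -> Vec<usize> {
    (0..n).map(|_| rng.below(n)).collect()
}

/// Majority vote over per-tree class predictions for a single sample,
/// breaking ties toward the smaller class label for determinism.
fn majority_vote(votes: &[usize]) -> usize {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    counts
        .into_iter()
        .max_by(|a, b| a.1.cmp(&b.1).then(b.0.cmp(&a.0)))
        .map(|(class, _)| class)
        .unwrap()
}

fn main() {
    let mut rng = Lcg(42);
    let sample = bootstrap_indices(5, &mut rng);
    assert_eq!(sample.len(), 5);
    assert!(sample.iter().all(|&i| i < 5));
    assert_eq!(majority_vote(&[1, 0, 1, 2, 1]), 1);
    println!("bootstrap + voting ok");
}
```

`Fit` would train one `DecisionTree` per bootstrap sample (restricted to a random feature subset), and `PredictInplace` would collect each tree's prediction per row and write the majority class into the output array.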
🔹 Example
An example will be added in:
`linfa-trees/examples/iris_random_forest.rs`
It will use the Iris dataset from `linfa-datasets`.
🔹 Benchmark (Optional)
If approved, I can also add a benchmark using Criterion:
`linfa-trees/benches/random_forest.rs`
📁 File Integration Plan
- `src/lib.rs`: re-export `random_forest::*`
- `src/decision_trees/mod.rs`: add `pub mod random_forest;`
- `README.md`: add a section on Random Forests with example usage
- `examples/iris_random_forest.rs`: demonstrate training and evaluation
📦 API Preview
```rust
let model = RandomForest::params()
    .n_trees(100)
    .feature_subsample(0.8)
    .max_depth(Some(10))
    .fit(&dataset)?;

let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();
```
✅ Conformity with CONTRIBUTING.md
- Uses the `Float` trait for `f32`/`f64` compatibility
- Follows the `Params` → `ValidParams` validation pattern
- Implements `Fit`, `Predict`, and `PredictInplace` using `Dataset`
- Optional `serde` support via a feature flag
- Will include unit tests and, optionally, benchmarks
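The serde point would follow the usual `cfg_attr` pattern so the dependency stays optional. This is a sketch under the assumption of a `serde` cargo feature forwarding to the serde crate; the struct fields are illustrative only.

```rust
/// Derives (de)serialization only when the `serde` cargo feature is enabled;
/// with the feature off, this compiles without the serde crate at all.
/// Field set is a hypothetical placeholder, not the final API.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RandomForestValidParams {
    pub n_trees: usize,
    pub feature_subsample: f64,
    pub max_depth: Option<usize>,
}

fn main() {
    let p = RandomForestValidParams {
        n_trees: 100,
        feature_subsample: 0.8,
        max_depth: Some(10),
    };
    assert_eq!(p.n_trees, 100);
    println!("struct constructed ok");
}
```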
🙋‍♂️ Request
Please let me know if you're open to this contribution. I’d be happy to align with maintainers on:
- Feature scope (classifier first, regressor later?)
- Benchmarking standards
- Integration strategy (e.g., reuse of `DecisionTree`)
Looking forward to your guidance!