|
| 1 | +use crate::EnsembleLearnerValidParams; |
| 2 | +use linfa::{ |
| 3 | + dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records}, |
| 4 | + error::Error, |
| 5 | + traits::*, |
| 6 | + DatasetBase, |
| 7 | +}; |
| 8 | +use ndarray::{Array2, Axis, Zip}; |
| 9 | +use rand::Rng; |
| 10 | +use std::{cmp::Eq, collections::HashMap, hash::Hash}; |
| 11 | + |
/// A fitted ensemble: a collection of independently trained models of the
/// same type `M` whose individual predictions are combined at prediction time.
///
/// `Debug` and `Clone` are derived so the ensemble can be inspected and
/// duplicated whenever the underlying model type supports it.
#[derive(Debug, Clone)]
pub struct EnsembleLearner<M> {
    /// The trained member models, in the order they were fitted.
    pub models: Vec<M>,
}
| 15 | + |
| 16 | +impl<M> EnsembleLearner<M> { |
| 17 | + // Generates prediction iterator returning predictions from each model |
| 18 | + pub fn generate_predictions<'b, R: Records, T>( |
| 19 | + &'b self, |
| 20 | + x: &'b R, |
| 21 | + ) -> impl Iterator<Item = T> + 'b |
| 22 | + where |
| 23 | + M: Predict<&'b R, T>, |
| 24 | + { |
| 25 | + self.models.iter().map(move |m| m.predict(x)) |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M> |
| 30 | +where |
| 31 | + M: PredictInplace<Array2<F>, T>, |
| 32 | + <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug, |
| 33 | + T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>, |
| 34 | +{ |
| 35 | + fn predict_inplace(&self, x: &Array2<F>, y: &mut T) { |
| 36 | + let y_array = y.as_targets(); |
| 37 | + assert_eq!( |
| 38 | + x.nrows(), |
| 39 | + y_array.len_of(Axis(0)), |
| 40 | + "The number of data points must match the number of outputs." |
| 41 | + ); |
| 42 | + |
| 43 | + let predictions = self.generate_predictions(x); |
| 44 | + |
| 45 | + // prediction map has same shape as y_array, but the elements are maps |
| 46 | + let mut prediction_maps = y_array.map(|_| HashMap::new()); |
| 47 | + |
| 48 | + for prediction in predictions { |
| 49 | + let p_arr = prediction.as_targets(); |
| 50 | + assert_eq!(p_arr.shape(), y_array.shape()); |
| 51 | + // Insert each prediction value into the corresponding map |
| 52 | + Zip::from(&mut prediction_maps) |
| 53 | + .and(&p_arr) |
| 54 | + .for_each(|map, val| *map.entry(*val).or_insert(0) += 1); |
| 55 | + } |
| 56 | + |
| 57 | + // For each prediction, pick the result with the highest number of votes |
| 58 | + let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0); |
| 59 | + let mut y_array = y.as_targets_mut(); |
| 60 | + for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) { |
| 61 | + *y = **pred |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + fn default_target(&self, x: &Array2<F>) -> T { |
| 66 | + self.models[0].default_target(x) |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error> |
| 71 | + for EnsembleLearnerValidParams<P, R> |
| 72 | +where |
| 73 | + D: Clone, |
| 74 | + T: FromTargetArrayOwned, |
| 75 | + T::Elem: Copy + Eq + Hash, |
| 76 | + T::Owned: AsTargets, |
| 77 | +{ |
| 78 | + type Object = EnsembleLearner<P::Object>; |
| 79 | + |
| 80 | + fn fit( |
| 81 | + &self, |
| 82 | + dataset: &DatasetBase<Array2<D>, T>, |
| 83 | + ) -> core::result::Result<Self::Object, Error> { |
| 84 | + let mut models = Vec::new(); |
| 85 | + let mut rng = self.rng.clone(); |
| 86 | + |
| 87 | + let dataset_size = |
| 88 | + ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize; |
| 89 | + |
| 90 | + let iter = dataset.bootstrap_samples(dataset_size, &mut rng); |
| 91 | + |
| 92 | + for train in iter { |
| 93 | + let model = self.model_params.fit(&train).unwrap(); |
| 94 | + models.push(model); |
| 95 | + |
| 96 | + if models.len() == self.ensemble_size { |
| 97 | + break; |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + Ok(EnsembleLearner { models }) |
| 102 | + } |
| 103 | +} |
0 commit comments