Skip to content

Commit dbcf0f0

Browse files
authored
extract_stat (#37)
* progress * typos * cargo fmt * progress * cleanup * cleanup * cargo fmt * inline
1 parent 6f58c99 commit dbcf0f0

File tree

4 files changed

+24
-24
lines changed

4 files changed

+24
-24
lines changed

src/data/mod.rs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,38 +216,36 @@ where
216216
}
217217
}
218218

219-
/// Convert a `DataOrSuffStat` into a `Stat`
220219
#[inline]
221220
pub fn extract_stat<X, Fx, Pr>(pr: &Pr, x: &DataOrSuffStat<X, Fx>) -> Fx::Stat
222221
where
223222
Fx: HasSuffStat<X> + HasDensity<X>,
224223
Fx::Stat: Clone,
225224
Pr: ConjugatePrior<X, Fx>,
226225
{
227-
match x {
228-
DataOrSuffStat::SuffStat(s) => (*s).clone(),
229-
DataOrSuffStat::Data(xs) => {
230-
let mut stat = pr.empty_stat();
231-
stat.observe_many(xs);
232-
stat
233-
}
234-
}
226+
extract_stat_then(pr, x, |s| s.clone())
235227
}
236228

237229
/// Convert a `DataOrSuffStat` into a `Stat` then do something with it
230+
#[inline]
238231
pub fn extract_stat_then<X, Fx, Pr, Fnx, Y>(
239232
pr: &Pr,
240233
x: &DataOrSuffStat<X, Fx>,
241234
f_stat: Fnx,
242235
) -> Y
243236
where
244237
Fx: HasSuffStat<X> + HasDensity<X>,
245-
Fx::Stat: Clone,
246238
Pr: ConjugatePrior<X, Fx>,
247-
Fnx: Fn(Fx::Stat) -> Y,
239+
Fnx: Fn(&Fx::Stat) -> Y,
248240
{
249-
let stat = extract_stat(pr, x);
250-
f_stat(stat)
241+
match x {
242+
DataOrSuffStat::SuffStat(s) => f_stat(s),
243+
DataOrSuffStat::Data(xs) => {
244+
let mut stat = pr.empty_stat();
245+
stat.observe_many(xs);
246+
f_stat(&stat)
247+
}
248+
}
251249
}
252250

253251
#[cfg(test)]
@@ -649,9 +647,10 @@ mod tests {
649647
let data: DataOrSuffStat<f64, Gaussian> =
650648
DataOrSuffStat::Data(&data_vec);
651649

652-
let result = extract_stat_then(&pr, &data, |stat| {
653-
stat.n() * 10 + (stat.sum_x() as usize)
654-
});
650+
let result =
651+
extract_stat_then(&pr, &data, |stat: &GaussianSuffStat| {
652+
stat.n() * 10 + (stat.sum_x() as usize)
653+
});
655654

656655
assert_eq!(result, 36); // 3 * 10 + 6
657656
}

src/dist/dirichlet/categorical_prior.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical>
3131
}
3232

3333
fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
34-
extract_stat_then(self, x, |stat: CategoricalSuffStat| {
34+
extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
3535
let alphas: Vec<f64> =
3636
stat.counts().iter().map(|&ct| self.alpha() + ct).collect();
3737

@@ -54,7 +54,7 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical>
5454
) -> f64 {
5555
let sum_alpha = self.alpha() * self.k() as f64;
5656

57-
extract_stat_then(self, x, |stat: CategoricalSuffStat| {
57+
extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
5858
let b = ln_gammafn(sum_alpha + stat.n() as f64);
5959
let c = stat
6060
.counts()
@@ -101,7 +101,7 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical> for Dirichlet {
101101
}
102102

103103
fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
104-
extract_stat_then(self, x, |stat: CategoricalSuffStat| {
104+
extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
105105
let alphas: Vec<f64> = self
106106
.alphas()
107107
.iter()
@@ -130,7 +130,7 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical> for Dirichlet {
130130
x: &CategoricalData<X>,
131131
) -> f64 {
132132
let (sum_alpha, ln_norm) = cache;
133-
extract_stat_then(self, x, |stat: CategoricalSuffStat| {
133+
extract_stat_then(self, x, |stat: &CategoricalSuffStat| {
134134
let b = ln_gammafn(sum_alpha + stat.n() as f64);
135135
let c = self
136136
.alphas()

src/dist/niw/mvg_prior.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::consts::LN_2PI;
2-
use crate::data::{extract_stat_then, DataOrSuffStat, MvGaussianSuffStat};
2+
use crate::data::extract_stat_then;
3+
use crate::data::{DataOrSuffStat, MvGaussianSuffStat};
34
use crate::dist::{MvGaussian, NormalInvWishart};
45
use crate::misc::lnmv_gamma;
56
use crate::traits::ConjugatePrior;
@@ -36,7 +37,7 @@ impl ConjugatePrior<DVector<f64>, MvGaussian> for NormalInvWishart {
3637
}
3738

3839
let nf = x.n() as f64;
39-
extract_stat_then(self, x, |stat: MvGaussianSuffStat| {
40+
extract_stat_then(self, x, |stat: &MvGaussianSuffStat| {
4041
let xbar = stat.sum_x() / stat.n() as f64;
4142
let diff = &xbar - self.mu();
4243
// s = \sum_{i=1}^N (x_i - \bar{x}) (x_i - \bar{x})^T

src/dist/normal_inv_chi_squared/gaussian_prior.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl ConjugatePrior<f64, Gaussian> for NormalInvChiSquared {
7070
}
7171

7272
fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
73-
extract_stat_then(self, x, |stat: GaussianSuffStat| {
73+
extract_stat_then(self, x, |stat: &GaussianSuffStat| {
7474
posterior_from_stat(self, &stat).into()
7575
})
7676
}
@@ -85,7 +85,7 @@ impl ConjugatePrior<f64, Gaussian> for NormalInvChiSquared {
8585
cache: &Self::MCache,
8686
x: &DataOrSuffStat<f64, Gaussian>,
8787
) -> f64 {
88-
extract_stat_then(self, x, |stat: GaussianSuffStat| {
88+
extract_stat_then(self, x, |stat: &GaussianSuffStat| {
8989
let n = stat.n() as f64;
9090
let post: Self = posterior_from_stat(self, &stat).into();
9191
let lnz_n = post.ln_z();

0 commit comments

Comments
 (0)