Skip to content

Commit 02654d4

Browse files
committed
shifted
1 parent cc5df2c commit 02654d4

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

src/dist/shifted_prior.rs

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::data::{DataOrSuffStat, ShiftedSuffStat};
2+
use crate::data::extract_stat_then;
23
use crate::dist::Shifted;
34
use crate::traits::*;
45
use rand::Rng;
@@ -141,22 +142,21 @@ where
141142
ShiftedSuffStat::new(parent_stat, self.shift)
142143
}
143144

145+
fn posterior_from_suffstat(
146+
&self,
147+
stat: &ShiftedSuffStat<Fx::Stat>,
148+
) -> Self::Posterior {
149+
ShiftedPrior::new_unchecked(
150+
self.parent.posterior_from_suffstat(&stat.parent()),
151+
self.shift,
152+
)
153+
}
154+
144155
fn posterior(
145156
&self,
146157
x: &DataOrSuffStat<f64, Shifted<Fx>>,
147158
) -> Self::Posterior {
148-
// For now, we'll just compute a new posterior with the same parameters
149-
// In the future, we should implement proper handling of the data
150-
let data: Vec<f64> = match x {
151-
DataOrSuffStat::Data(xs) => {
152-
xs.iter().map(|&x| x - self.shift).collect()
153-
}
154-
DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now
155-
};
156-
157-
let posterior_parent =
158-
self.parent.posterior(&DataOrSuffStat::Data(&data));
159-
Self::new_unchecked(posterior_parent, self.shift)
159+
extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat))
160160
}
161161

162162
fn ln_m_cache(&self) -> Self::MCache {
@@ -184,15 +184,10 @@ where
184184
&self,
185185
x: &DataOrSuffStat<f64, Shifted<Fx>>,
186186
) -> Self::PpCache {
187-
// For now, we'll just compute from data
188-
let data: Vec<f64> = match x {
189-
DataOrSuffStat::Data(xs) => {
190-
xs.iter().map(|&x| x - self.shift).collect()
191-
}
192-
DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now
193-
};
194-
195-
self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data))
187+
extract_stat_then(self, x, |stat| {
188+
self.parent
189+
.ln_pp_cache(&DataOrSuffStat::SuffStat(stat.parent()))
190+
})
196191
}
197192

198193
fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {

0 commit comments

Comments
 (0)