|
1 | 1 | use crate::data::{DataOrSuffStat, ShiftedSuffStat}; |
| 2 | +use crate::data::extract_stat_then; |
2 | 3 | use crate::dist::Shifted; |
3 | 4 | use crate::traits::*; |
4 | 5 | use rand::Rng; |
@@ -141,22 +142,21 @@ where |
141 | 142 | ShiftedSuffStat::new(parent_stat, self.shift) |
142 | 143 | } |
143 | 144 |
|
| 145 | + fn posterior_from_suffstat( |
| 146 | + &self, |
| 147 | + stat: &ShiftedSuffStat<Fx::Stat>, |
| 148 | + ) -> Self::Posterior { |
| 149 | + ShiftedPrior::new_unchecked( |
| 150 | + self.parent.posterior_from_suffstat(&stat.parent()), |
| 151 | + self.shift, |
| 152 | + ) |
| 153 | + } |
| 154 | + |
144 | 155 | fn posterior( |
145 | 156 | &self, |
146 | 157 | x: &DataOrSuffStat<f64, Shifted<Fx>>, |
147 | 158 | ) -> Self::Posterior { |
148 | | - // For now, we'll just compute a new posterior with the same parameters |
149 | | - // In the future, we should implement proper handling of the data |
150 | | - let data: Vec<f64> = match x { |
151 | | - DataOrSuffStat::Data(xs) => { |
152 | | - xs.iter().map(|&x| x - self.shift).collect() |
153 | | - } |
154 | | - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now |
155 | | - }; |
156 | | - |
157 | | - let posterior_parent = |
158 | | - self.parent.posterior(&DataOrSuffStat::Data(&data)); |
159 | | - Self::new_unchecked(posterior_parent, self.shift) |
| 159 | + extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat)) |
160 | 160 | } |
161 | 161 |
|
162 | 162 | fn ln_m_cache(&self) -> Self::MCache { |
@@ -184,15 +184,10 @@ where |
184 | 184 | &self, |
185 | 185 | x: &DataOrSuffStat<f64, Shifted<Fx>>, |
186 | 186 | ) -> Self::PpCache { |
187 | | - // For now, we'll just compute from data |
188 | | - let data: Vec<f64> = match x { |
189 | | - DataOrSuffStat::Data(xs) => { |
190 | | - xs.iter().map(|&x| x - self.shift).collect() |
191 | | - } |
192 | | - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now |
193 | | - }; |
194 | | - |
195 | | - self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) |
| 187 | + extract_stat_then(self, x, |stat| { |
| 188 | + self.parent |
| 189 | + .ln_pp_cache(&DataOrSuffStat::SuffStat(stat.parent())) |
| 190 | + }) |
196 | 191 | } |
197 | 192 |
|
198 | 193 | fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { |
|
0 commit comments