From b61a716e9c8471cfda5cde9d336acad972aae8ea Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Sat, 19 Apr 2025 10:55:51 -0700 Subject: [PATCH 01/17] ScaledPrior --- src/dist/mod.rs | 2 + src/dist/scaled_prior.rs | 303 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+) create mode 100644 src/dist/scaled_prior.rs diff --git a/src/dist/mod.rs b/src/dist/mod.rs index b6e0a0e1..086ab49a 100644 --- a/src/dist/mod.rs +++ b/src/dist/mod.rs @@ -45,6 +45,7 @@ mod pareto; mod poisson; mod scaled; mod scaled_inv_chi_squared; +mod scaled_prior; mod shifted; mod skellam; mod students_t; @@ -108,6 +109,7 @@ pub use scaled_inv_chi_squared::{ ScaledInvChiSquared, ScaledInvChiSquaredError, ScaledInvChiSquaredParameters, }; +pub use scaled_prior::{ScaledPrior, ScaledPriorError}; pub use shifted::{Shifted, ShiftedError, ShiftedParameters}; pub use skellam::{Skellam, SkellamError, SkellamParameters}; pub use students_t::{StudentsT, StudentsTError, StudentsTParameters}; diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs new file mode 100644 index 00000000..299a7321 --- /dev/null +++ b/src/dist/scaled_prior.rs @@ -0,0 +1,303 @@ +use crate::dist::Scaled; +use crate::traits::*; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::marker::PhantomData; +use crate::data::{DataOrSuffStat, ScaledSuffStat}; + +/// A wrapper for priors that scales the output distribution +/// +/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ScaledPrior` +/// will produce a `Scaled`. +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +pub struct ScaledPrior +where + Pr: Sampleable, + Fx: Scalable, +{ + parent: Pr, + scale: f64, + rate: f64, + logjac: f64, + _phantom: PhantomData, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ScaledPriorError { + /// The scale parameter must be a normal (finite, non-zero, non-subnormal) number + NonNormalScale(f64), + /// The scale parameter must be positive + NegativeScale(f64), +} + +impl std::error::Error for ScaledPriorError {} + +impl fmt::Display for ScaledPriorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NonNormalScale(scale) => { + write!(f, "non-normal scale: {}", scale) + } + Self::NegativeScale(scale) => { + write!(f, "negative scale: {}", scale) + } + } + } +} + +impl ScaledPrior +where + Pr: Sampleable, + Fx: Scalable, +{ + /// Creates a new scaled prior with the given parent prior and scale factor. + /// + /// # Errors + /// Returns `ScaledPriorError::NonNormalScale` if the scale parameter is not a + /// normal number (i.e., if it's zero, infinite, or NaN). + /// + /// Returns `ScaledPriorError::NegativeScale` if the scale parameter is not positive. + pub fn new(parent: Pr, scale: f64) -> Result { + if !scale.is_normal() { + Err(ScaledPriorError::NonNormalScale(scale)) + } else if scale <= 0.0 { + Err(ScaledPriorError::NegativeScale(scale)) + } else { + Ok(ScaledPrior { + parent, + scale, + rate: scale.recip(), + logjac: scale.abs().ln(), + _phantom: PhantomData, + }) + } + } + + /// Creates a new scaled prior with the given parent prior and scale factor, + /// without checking the scale parameter. + /// + /// # Safety + /// The scale parameter must be a positive normal (finite, non-zero, + /// non-subnormal) number. + pub fn new_unchecked(parent: Pr, scale: f64) -> Self { + ScaledPrior { + parent, + scale, + rate: scale.recip(), + logjac: scale.abs().ln(), + _phantom: PhantomData, + } + } + + pub fn parent(&self) -> &Pr { + &self.parent + } + + pub fn parent_mut(&mut self) -> &mut Pr { + &mut self.parent + } + + pub fn scale(&self) -> f64 { + self.scale + } + + pub fn rate(&self) -> f64 { + self.rate + } + + pub fn logjac(&self) -> f64 { + self.logjac + } +} + +impl Sampleable> for ScaledPrior +where + Pr: Sampleable, + Fx: Scalable, +{ + fn draw(&self, rng: &mut R) -> Scaled { + let fx = self.parent.draw(rng); + Scaled::new_unchecked(fx, self.scale) + } +} + +pub struct ScaledPriorParameters { + parent: Pr::Parameters, + scale: f64, +} + +impl Parameterized for ScaledPrior +where + Pr: Sampleable + Parameterized, + Fx: Scalable, +{ + type Parameters = ScaledPriorParameters; + + fn emit_params(&self) -> Self::Parameters { + ScaledPriorParameters { + parent: self.parent.emit_params(), + scale: self.scale, + } + } + + fn from_params(params: Self::Parameters) -> Self { + let parent = Pr::from_params(params.parent); + Self::new_unchecked(parent, params.scale) + } +} + +/// Helper trait to convert between scaled and unscaled data +/// +/// This is used internally by the ConjugatePrior implementation for ScaledPrior. +trait ScaleData { + /// Scale the data by the given factor + fn scale_data(&self, scale: f64) -> Self; +} + +impl + Clone> ScaleData for X { + fn scale_data(&self, scale: f64) -> Self { + self.clone() * scale + } +} + +impl + Clone> ScaleData for Vec { + fn scale_data(&self, scale: f64) -> Self { + self.iter().map(|x| x.clone() * scale).collect() + } +} + +impl + Clone> ScaleData for &[X] { + fn scale_data(&self, scale: f64) -> Self { + &self.iter().map(|x| x.clone() * scale).collect::>()[..] + } +} + +impl ConjugatePrior> for ScaledPrior +where + Pr: ConjugatePrior, + Fx: HasDensity + HasSuffStat + Scalable, + X: std::ops::Mul + Clone, +{ + type Posterior = ScaledPrior; + type MCache = Pr::MCache; + type PpCache = (Pr::PpCache, f64); // Parent cache and scale + + fn empty_stat(&self) -> as HasSuffStat>::Stat { + // For a Scaled, the Stat is ScaledSuffStat + ScaledSuffStat::new(self.parent.empty_stat(), self.scale) + } + + fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { + // We need to convert the data for Scaled into data for Fx + // This means scaling by rate = 1/scale + let parent_data = match x { + DataOrSuffStat::Data(data) => { + // Scale the data by rate + let scaled_data = data.scale_data(self.rate); + DataOrSuffStat::Data(&scaled_data) + } + DataOrSuffStat::SuffStat(stat) => { + // The stat is already set up to handle the scaling + // Just access the parent stat directly + DataOrSuffStat::SuffStat(stat.parent()) + } + }; + + // Get posterior from parent + let parent_posterior = self.parent.posterior(&parent_data); + + // Wrap in ScaledPrior with the same scale + ScaledPrior::new_unchecked(parent_posterior, self.scale) + } + + fn ln_m_cache(&self) -> Self::MCache { + self.parent.ln_m_cache() + } + + fn ln_m_with_cache( + &self, + cache: &Self::MCache, + x: &DataOrSuffStat>, + ) -> f64 { + // We need to convert the data for Scaled into data for Fx + let parent_data = match x { + DataOrSuffStat::Data(data) => { + // Scale the data by rate + let scaled_data = data.scale_data(self.rate); + DataOrSuffStat::Data(&scaled_data) + } + DataOrSuffStat::SuffStat(stat) => { + // The stat is already set up to handle the scaling + // Just access the parent stat directly + DataOrSuffStat::SuffStat(stat.parent()) + } + }; + + // Use parent's ln_m_with_cache + self.parent.ln_m_with_cache(cache, &parent_data) + } + + fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { + // We need to convert the data for Scaled into data for Fx + let parent_data = match x { + DataOrSuffStat::Data(data) => { + // Scale the data by rate + let scaled_data = data.scale_data(self.rate); + DataOrSuffStat::Data(&scaled_data) + } + DataOrSuffStat::SuffStat(stat) => { + // The stat is already set up to handle the scaling + // Just access the parent stat directly + DataOrSuffStat::SuffStat(stat.parent()) + } + }; + + // Get cache from parent and save our scale + (self.parent.ln_pp_cache(&parent_data), self.scale) + } + + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { + // Unpack the cache + let (parent_cache, scale) = cache; + + // Scale y by rate to get into parent's space + let scaled_y = y.clone() * self.rate; + + // Use parent's ln_pp_with_cache + self.parent.ln_pp_with_cache(parent_cache, &scaled_y) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dist::{Gaussian, NormalInvChiSquared}; + use rand::SeedableRng; + use rand_xoshiro::Xoshiro256Plus; + + #[test] + fn test_scaled_prior_draw() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); + + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let dist = scaled_prior.draw(&mut rng); + + assert_eq!(dist.scale(), 2.0); + } + + #[test] + fn test_scale_data() { + let x = 2.0; + let scaled = x.scale_data(3.0); + assert_eq!(scaled, 6.0); + + let vec = vec![1.0, 2.0, 3.0]; + let scaled_vec = vec.scale_data(2.0); + assert_eq!(scaled_vec, vec![2.0, 4.0, 6.0]); + } +} \ No newline at end of file From ed6fd3f3ce24937fbda549a0703c0845df6b4086 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Sat, 19 Apr 2025 11:06:57 -0700 Subject: [PATCH 02/17] update --- src/dist/scaled_prior.rs | 180 +++++++++++++++++---------------------- 1 file changed, 79 insertions(+), 101 deletions(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 299a7321..c946b7e8 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -150,68 +150,31 @@ where } } -/// Helper trait to convert between scaled and unscaled data -/// -/// This is used internally by the ConjugatePrior implementation for ScaledPrior. -trait ScaleData { - /// Scale the data by the given factor - fn scale_data(&self, scale: f64) -> Self; -} - -impl + Clone> ScaleData for X { - fn scale_data(&self, scale: f64) -> Self { - self.clone() * scale - } -} - -impl + Clone> ScaleData for Vec { - fn scale_data(&self, scale: f64) -> Self { - self.iter().map(|x| x.clone() * scale).collect() - } -} - -impl + Clone> ScaleData for &[X] { - fn scale_data(&self, scale: f64) -> Self { - &self.iter().map(|x| x.clone() * scale).collect::>()[..] - } -} - -impl ConjugatePrior> for ScaledPrior +impl ConjugatePrior> for ScaledPrior where - Pr: ConjugatePrior, - Fx: HasDensity + HasSuffStat + Scalable, - X: std::ops::Mul + Clone, + Pr: ConjugatePrior, + Fx: HasSuffStat + Scalable + HasDensity, + Scaled: HasSuffStat>, { - type Posterior = ScaledPrior; + type Posterior = Self; type MCache = Pr::MCache; - type PpCache = (Pr::PpCache, f64); // Parent cache and scale + type PpCache = Pr::PpCache; - fn empty_stat(&self) -> as HasSuffStat>::Stat { - // For a Scaled, the Stat is ScaledSuffStat - ScaledSuffStat::new(self.parent.empty_stat(), self.scale) + fn empty_stat(&self) -> ScaledSuffStat { + let parent_stat = self.parent.empty_stat(); + ScaledSuffStat::new(parent_stat, self.scale) } - fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { - // We need to convert the data for Scaled into data for Fx - // This means scaling by rate = 1/scale - let parent_data = match x { - DataOrSuffStat::Data(data) => { - // Scale the data by rate - let scaled_data = data.scale_data(self.rate); - DataOrSuffStat::Data(&scaled_data) - } - DataOrSuffStat::SuffStat(stat) => { - // The stat is already set up to handle the scaling - // Just access the parent stat directly - DataOrSuffStat::SuffStat(stat.parent()) - } + fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { + // For now, we'll just compute a new posterior with the same parameters + // In the future, we should implement proper handling of the data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - // Get posterior from parent - let parent_posterior = self.parent.posterior(&parent_data); - - // Wrap in ScaledPrior with the same scale - ScaledPrior::new_unchecked(parent_posterior, self.scale) + let posterior_parent = self.parent.posterior(&DataOrSuffStat::Data(&data)); + Self::new_unchecked(posterior_parent, self.scale) } fn ln_m_cache(&self) -> Self::MCache { @@ -221,61 +184,42 @@ where fn ln_m_with_cache( &self, cache: &Self::MCache, - x: &DataOrSuffStat>, + x: &DataOrSuffStat>, ) -> f64 { - // We need to convert the data for Scaled into data for Fx - let parent_data = match x { - DataOrSuffStat::Data(data) => { - // Scale the data by rate - let scaled_data = data.scale_data(self.rate); - DataOrSuffStat::Data(&scaled_data) - } - DataOrSuffStat::SuffStat(stat) => { - // The stat is already set up to handle the scaling - // Just access the parent stat directly - DataOrSuffStat::SuffStat(stat.parent()) - } + // For now, we'll just compute from data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - // Use parent's ln_m_with_cache - self.parent.ln_m_with_cache(cache, &parent_data) + self.parent.ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) } - fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { - // We need to convert the data for Scaled into data for Fx - let parent_data = match x { - DataOrSuffStat::Data(data) => { - // Scale the data by rate - let scaled_data = data.scale_data(self.rate); - DataOrSuffStat::Data(&scaled_data) - } - DataOrSuffStat::SuffStat(stat) => { - // The stat is already set up to handle the scaling - // Just access the parent stat directly - DataOrSuffStat::SuffStat(stat.parent()) - } + fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { + // For now, we'll just compute from data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - // Get cache from parent and save our scale - (self.parent.ln_pp_cache(&parent_data), self.scale) + self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) } - fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { - // Unpack the cache - let (parent_cache, scale) = cache; - - // Scale y by rate to get into parent's space - let scaled_y = y.clone() * self.rate; - - // Use parent's ln_pp_with_cache - self.parent.ln_pp_with_cache(parent_cache, &scaled_y) + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { + // Scale y back to the parent distribution's space + let scaled_y = *y * self.rate; + // Compute the log posterior predictive using the parent + // Add the log Jacobian adjustment for the scale + self.parent.ln_pp_with_cache(cache, &scaled_y) - self.logjac } } #[cfg(test)] mod tests { use super::*; - use crate::dist::{Gaussian, NormalInvChiSquared}; + use crate::data::DataOrSuffStat; + use crate::dist::{Gaussian, NormalInvChiSquared, Scaled}; + use crate::traits::*; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; @@ -289,15 +233,49 @@ mod tests { assert_eq!(dist.scale(), 2.0); } + + #[test] + fn test_scaled_prior_conjugate() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); + + // Create an empty stat to test + let stat = scaled_prior.empty_stat(); + assert_eq!(stat.scale(), 2.0); + + // Test posterior with empty data + let data: Vec = Vec::new(); + // Manually create DataOrSuffStat instead of using .into() + let dos = DataOrSuffStat::Data(&data); + let posterior = scaled_prior.posterior(&dos); + + // Scale should persist through posterior computation + assert_eq!(posterior.scale(), 2.0); + } #[test] - fn test_scale_data() { - let x = 2.0; - let scaled = x.scale_data(3.0); - assert_eq!(scaled, 6.0); + fn test_scaled_prior_with_data() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); + + // Create some data - will be scaled by 1/2 internally for parent calculations + let data = vec![2.0, 4.0, 6.0]; + + // Manually create DataOrSuffStat instead of using .into() + let dos = DataOrSuffStat::Data(&data); + + // Compute posterior + let posterior = scaled_prior.posterior(&dos); + + // Scale should persist through posterior computation + assert_eq!(posterior.scale(), 2.0); + + // Verify ln_m and ln_pp work + let ln_m = scaled_prior.ln_m(&dos); + let ln_pp = scaled_prior.ln_pp(&2.0, &dos); - let vec = vec![1.0, 2.0, 3.0]; - let scaled_vec = vec.scale_data(2.0); - assert_eq!(scaled_vec, vec![2.0, 4.0, 6.0]); + // Values should be finite (actual values will depend on implementation) + assert!(ln_m.is_finite()); + assert!(ln_pp.is_finite()); } } \ No newline at end of file From 2440fdc995fd9c107f94aa2dc541a2853de9e4d3 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 07:17:51 -0700 Subject: [PATCH 03/17] shifted --- src/dist/mod.rs | 2 + src/dist/shifted_prior.rs | 256 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 src/dist/shifted_prior.rs diff --git a/src/dist/mod.rs b/src/dist/mod.rs index 086ab49a..1b78a998 100644 --- a/src/dist/mod.rs +++ b/src/dist/mod.rs @@ -47,6 +47,7 @@ mod scaled; mod scaled_inv_chi_squared; mod scaled_prior; mod shifted; +mod shifted_prior; mod skellam; mod students_t; mod uniform; @@ -111,6 +112,7 @@ pub use scaled_inv_chi_squared::{ }; pub use scaled_prior::{ScaledPrior, ScaledPriorError}; pub use shifted::{Shifted, ShiftedError, ShiftedParameters}; +pub use shifted_prior::{ShiftedPrior, ShiftedPriorError}; pub use skellam::{Skellam, SkellamError, SkellamParameters}; pub use students_t::{StudentsT, StudentsTError, StudentsTParameters}; pub use uniform::{Uniform, UniformError, UniformParameters}; diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs new file mode 100644 index 00000000..a0123c1e --- /dev/null +++ b/src/dist/shifted_prior.rs @@ -0,0 +1,256 @@ +use crate::dist::Shifted; +use crate::traits::*; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::marker::PhantomData; +use crate::data::{DataOrSuffStat, ShiftedSuffStat}; + +/// A wrapper for priors that shifts the output distribution +/// +/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ShiftedPrior` +/// will produce a `Shifted`. +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +pub struct ShiftedPrior +where + Pr: Sampleable, + Fx: Shiftable, +{ + parent: Pr, + shift: f64, + _phantom: PhantomData, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ShiftedPriorError { + /// The shift parameter must be a finite number + NonFiniteShift(f64), +} + +impl std::error::Error for ShiftedPriorError {} + +impl fmt::Display for ShiftedPriorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NonFiniteShift(shift) => { + write!(f, "non-finite shift: {}", shift) + } + } + } +} + +impl ShiftedPrior +where + Pr: Sampleable, + Fx: Shiftable, +{ + /// Creates a new shifted prior with the given parent prior and shift value. + /// + /// # Errors + /// Returns `ShiftedPriorError::NonFiniteShift` if the shift parameter is not + /// a finite number (i.e., if it's infinite or NaN). + pub fn new(parent: Pr, shift: f64) -> Result { + if !shift.is_finite() { + Err(ShiftedPriorError::NonFiniteShift(shift)) + } else { + Ok(ShiftedPrior { + parent, + shift, + _phantom: PhantomData, + }) + } + } + + /// Creates a new shifted prior with the given parent prior and shift value, + /// without checking the shift parameter. + /// + /// # Safety + /// The shift parameter must be a finite number. + pub fn new_unchecked(parent: Pr, shift: f64) -> Self { + ShiftedPrior { + parent, + shift, + _phantom: PhantomData, + } + } + + pub fn parent(&self) -> &Pr { + &self.parent + } + + pub fn parent_mut(&mut self) -> &mut Pr { + &mut self.parent + } + + pub fn shift(&self) -> f64 { + self.shift + } +} + +impl Sampleable> for ShiftedPrior +where + Pr: Sampleable, + Fx: Shiftable, +{ + fn draw(&self, rng: &mut R) -> Shifted { + let fx = self.parent.draw(rng); + Shifted::new_unchecked(fx, self.shift) + } +} + +pub struct ShiftedPriorParameters { + parent: Pr::Parameters, + shift: f64, +} + +impl Parameterized for ShiftedPrior +where + Pr: Sampleable + Parameterized, + Fx: Shiftable, +{ + type Parameters = ShiftedPriorParameters; + + fn emit_params(&self) -> Self::Parameters { + ShiftedPriorParameters { + parent: self.parent.emit_params(), + shift: self.shift, + } + } + + fn from_params(params: Self::Parameters) -> Self { + let parent = Pr::from_params(params.parent); + Self::new_unchecked(parent, params.shift) + } +} + +impl ConjugatePrior> for ShiftedPrior +where + Pr: ConjugatePrior, + Fx: HasSuffStat + Shiftable + HasDensity, + Shifted: HasSuffStat>, +{ + type Posterior = Self; + type MCache = Pr::MCache; + type PpCache = Pr::PpCache; + + fn empty_stat(&self) -> ShiftedSuffStat { + let parent_stat = self.parent.empty_stat(); + ShiftedSuffStat::new(parent_stat, self.shift) + } + + fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { + // For now, we'll just compute a new posterior with the same parameters + // In the future, we should implement proper handling of the data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now + }; + + let posterior_parent = self.parent.posterior(&DataOrSuffStat::Data(&data)); + Self::new_unchecked(posterior_parent, self.shift) + } + + fn ln_m_cache(&self) -> Self::MCache { + self.parent.ln_m_cache() + } + + fn ln_m_with_cache( + &self, + cache: &Self::MCache, + x: &DataOrSuffStat>, + ) -> f64 { + // For now, we'll just compute from data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now + }; + + self.parent.ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) + } + + fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { + // For now, we'll just compute from data + let data: Vec = match x { + DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now + }; + + self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) + } + + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { + // Shift y back to the parent distribution's space + let shifted_y = *y - self.shift; + // Compute the log posterior predictive using the parent + self.parent.ln_pp_with_cache(cache, &shifted_y) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::data::DataOrSuffStat; + use crate::dist::{Gaussian, NormalInvChiSquared, Shifted}; + use crate::traits::*; + use rand::SeedableRng; + use rand_xoshiro::Xoshiro256Plus; + + #[test] + fn test_shifted_prior_draw() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); + + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let dist = shifted_prior.draw(&mut rng); + + assert_eq!(dist.shift(), 2.0); + } + + #[test] + fn test_shifted_prior_conjugate() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); + + // Create an empty stat to test + let stat = shifted_prior.empty_stat(); + assert_eq!(stat.shift(), 2.0); + + // Test posterior with empty data + let data: Vec = Vec::new(); + // Manually create DataOrSuffStat instead of using .into() + let dos = DataOrSuffStat::Data(&data); + let posterior = shifted_prior.posterior(&dos); + + // Shift should persist through posterior computation + assert_eq!(posterior.shift(), 2.0); + } + + #[test] + fn test_shifted_prior_with_data() { + let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); + let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); + + // Create some data - will be shifted by -2.0 internally for parent calculations + let data = vec![2.0, 4.0, 6.0]; + + // Manually create DataOrSuffStat instead of using .into() + let dos = DataOrSuffStat::Data(&data); + + // Compute posterior + let posterior = shifted_prior.posterior(&dos); + + // Shift should persist through posterior computation + assert_eq!(posterior.shift(), 2.0); + + // Verify ln_m and ln_pp work + let ln_m = shifted_prior.ln_m(&dos); + let ln_pp = shifted_prior.ln_pp(&2.0, &dos); + + // Values should be finite (actual values will depend on implementation) + assert!(ln_m.is_finite()); + assert!(ln_pp.is_finite()); + } +} \ No newline at end of file From 83919cfd87691877c540fcbdd9f38f843107f24f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 07:17:59 -0700 Subject: [PATCH 04/17] shifted --- src/dist/shifted.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dist/shifted.rs b/src/dist/shifted.rs index 27d74c61..a74e1377 100644 --- a/src/dist/shifted.rs +++ b/src/dist/shifted.rs @@ -42,6 +42,10 @@ impl Shifted { pub fn new_unchecked(parent: D, shift: f64) -> Self { Shifted { parent, shift } } + + pub fn shift(&self) -> f64 { + self.shift + } } impl Sampleable for Shifted From 63e20fb693aa16cd4dbdc0010777f2a511773bfc Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 07:19:07 -0700 Subject: [PATCH 05/17] cargo fmt --- src/dist/scaled_prior.rs | 66 ++++++++++++++++++++++++--------------- src/dist/shifted_prior.rs | 66 ++++++++++++++++++++++++--------------- 2 files changed, 80 insertions(+), 52 deletions(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index c946b7e8..af028ac1 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -1,3 +1,4 @@ +use crate::data::{DataOrSuffStat, ScaledSuffStat}; use crate::dist::Scaled; use crate::traits::*; use rand::Rng; @@ -5,11 +6,10 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use std::fmt; use std::marker::PhantomData; -use crate::data::{DataOrSuffStat, ScaledSuffStat}; /// A wrapper for priors that scales the output distribution -/// -/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ScaledPrior` +/// +/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ScaledPrior` /// will produce a `Scaled`. #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] @@ -165,15 +165,21 @@ where ScaledSuffStat::new(parent_stat, self.scale) } - fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { + fn posterior( + &self, + x: &DataOrSuffStat>, + ) -> Self::Posterior { // For now, we'll just compute a new posterior with the same parameters // In the future, we should implement proper handling of the data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x * self.rate).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - - let posterior_parent = self.parent.posterior(&DataOrSuffStat::Data(&data)); + + let posterior_parent = + self.parent.posterior(&DataOrSuffStat::Data(&data)); Self::new_unchecked(posterior_parent, self.scale) } @@ -188,20 +194,28 @@ where ) -> f64 { // For now, we'll just compute from data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x * self.rate).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - - self.parent.ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) + + self.parent + .ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) } - fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { + fn ln_pp_cache( + &self, + x: &DataOrSuffStat>, + ) -> Self::PpCache { // For now, we'll just compute from data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x * self.rate).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x * self.rate).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - + self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) } @@ -227,10 +241,10 @@ mod tests { fn test_scaled_prior_draw() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); - + let mut rng = Xoshiro256Plus::seed_from_u64(42); let dist = scaled_prior.draw(&mut rng); - + assert_eq!(dist.scale(), 2.0); } @@ -238,44 +252,44 @@ mod tests { fn test_scaled_prior_conjugate() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); - + // Create an empty stat to test let stat = scaled_prior.empty_stat(); assert_eq!(stat.scale(), 2.0); - + // Test posterior with empty data let data: Vec = Vec::new(); // Manually create DataOrSuffStat instead of using .into() let dos = DataOrSuffStat::Data(&data); let posterior = scaled_prior.posterior(&dos); - + // Scale should persist through posterior computation assert_eq!(posterior.scale(), 2.0); } - + #[test] fn test_scaled_prior_with_data() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let scaled_prior = ScaledPrior::new(prior, 2.0).unwrap(); - + // Create some data - will be scaled by 1/2 internally for parent calculations let data = vec![2.0, 4.0, 6.0]; - + // Manually create DataOrSuffStat instead of using .into() let dos = DataOrSuffStat::Data(&data); - + // Compute posterior let posterior = scaled_prior.posterior(&dos); - + // Scale should persist through posterior computation assert_eq!(posterior.scale(), 2.0); - + // Verify ln_m and ln_pp work let ln_m = scaled_prior.ln_m(&dos); let ln_pp = scaled_prior.ln_pp(&2.0, &dos); - + // Values should be finite (actual values will depend on implementation) assert!(ln_m.is_finite()); assert!(ln_pp.is_finite()); } -} \ No newline at end of file +} diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs index a0123c1e..d575a912 100644 --- a/src/dist/shifted_prior.rs +++ b/src/dist/shifted_prior.rs @@ -1,3 +1,4 @@ +use crate::data::{DataOrSuffStat, ShiftedSuffStat}; use crate::dist::Shifted; use crate::traits::*; use rand::Rng; @@ -5,11 +6,10 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use std::fmt; use std::marker::PhantomData; -use crate::data::{DataOrSuffStat, ShiftedSuffStat}; /// A wrapper for priors that shifts the output distribution -/// -/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ShiftedPrior` +/// +/// If drawing a `Pr` gives a distribution `Fx`, then drawing `ShiftedPrior` /// will produce a `Shifted`. #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] @@ -141,15 +141,21 @@ where ShiftedSuffStat::new(parent_stat, self.shift) } - fn posterior(&self, x: &DataOrSuffStat>) -> Self::Posterior { + fn posterior( + &self, + x: &DataOrSuffStat>, + ) -> Self::Posterior { // For now, we'll just compute a new posterior with the same parameters // In the future, we should implement proper handling of the data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x - self.shift).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - - let posterior_parent = self.parent.posterior(&DataOrSuffStat::Data(&data)); + + let posterior_parent = + self.parent.posterior(&DataOrSuffStat::Data(&data)); Self::new_unchecked(posterior_parent, self.shift) } @@ -164,20 +170,28 @@ where ) -> f64 { // For now, we'll just compute from data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x - self.shift).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - - self.parent.ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) + + self.parent + .ln_m_with_cache(cache, &DataOrSuffStat::Data(&data)) } - fn ln_pp_cache(&self, x: &DataOrSuffStat>) -> Self::PpCache { + fn ln_pp_cache( + &self, + x: &DataOrSuffStat>, + ) -> Self::PpCache { // For now, we'll just compute from data let data: Vec = match x { - DataOrSuffStat::Data(xs) => xs.iter().map(|&x| x - self.shift).collect(), + DataOrSuffStat::Data(xs) => { + xs.iter().map(|&x| x - self.shift).collect() + } DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now }; - + self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) } @@ -202,10 +216,10 @@ mod tests { fn test_shifted_prior_draw() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); - + let mut rng = Xoshiro256Plus::seed_from_u64(42); let dist = shifted_prior.draw(&mut rng); - + assert_eq!(dist.shift(), 2.0); } @@ -213,44 +227,44 @@ mod tests { fn test_shifted_prior_conjugate() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); - + // Create an empty stat to test let stat = shifted_prior.empty_stat(); assert_eq!(stat.shift(), 2.0); - + // Test posterior with empty data let data: Vec = Vec::new(); // Manually create DataOrSuffStat instead of using .into() let dos = DataOrSuffStat::Data(&data); let posterior = shifted_prior.posterior(&dos); - + // Shift should persist through posterior computation assert_eq!(posterior.shift(), 2.0); } - + #[test] fn test_shifted_prior_with_data() { let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0); let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap(); - + // Create some data - will be shifted by -2.0 internally for parent calculations let data = vec![2.0, 4.0, 6.0]; - + // Manually create DataOrSuffStat instead of using .into() let dos = DataOrSuffStat::Data(&data); - + // Compute posterior let posterior = shifted_prior.posterior(&dos); - + // Shift should persist through posterior computation assert_eq!(posterior.shift(), 2.0); - + // Verify ln_m and ln_pp work let ln_m = shifted_prior.ln_m(&dos); let ln_pp = shifted_prior.ln_pp(&2.0, &dos); - + // Values should be finite (actual values will depend on implementation) assert!(ln_m.is_finite()); assert!(ln_pp.is_finite()); } -} \ No newline at end of file +} From 78611e16b60233d3d4f989bcddbde0c19b3f6a01 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 08:32:56 -0700 Subject: [PATCH 06/17] udpate --- src/dist/scaled_prior.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index af028ac1..d5c3fe3b 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -156,7 +156,7 @@ where Fx: HasSuffStat + Scalable + HasDensity, Scaled: HasSuffStat>, { - type Posterior = Self; + type Posterior = ScaledPrior; type MCache = Pr::MCache; type PpCache = Pr::PpCache; From 06f3eac7a559fe5d1b317ce86f462e96a6797cbb Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 10:51:34 -0700 Subject: [PATCH 07/17] work --- src/data/mod.rs | 2 +- src/dist/scaled_prior.rs | 19 +++++++------------ src/traits.rs | 5 ++++- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/data/mod.rs b/src/data/mod.rs index a877fa6d..fc8757af 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -235,7 +235,7 @@ pub fn extract_stat_then( ) -> Y where Fx: HasSuffStat + HasDensity, - Pr: ConjugatePrior, + Pr: ConjugatePrior + ?Sized, Fnx: Fn(&Fx::Stat) -> Y, { match x { diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index d5c3fe3b..5a5ab086 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -1,3 +1,4 @@ +use crate::data::extract_stat_then; use crate::data::{DataOrSuffStat, ScaledSuffStat}; use crate::dist::Scaled; use crate::traits::*; @@ -169,18 +170,12 @@ where &self, x: &DataOrSuffStat>, ) -> Self::Posterior { - // For now, we'll just compute a new posterior with the same parameters - // In the future, we should implement proper handling of the data - let data: Vec = match x { - DataOrSuffStat::Data(xs) => { - xs.iter().map(|&x| x * self.rate).collect() - } - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now - }; - - let posterior_parent = - self.parent.posterior(&DataOrSuffStat::Data(&data)); - Self::new_unchecked(posterior_parent, self.scale) + extract_stat_then(self, x, |stat: &ScaledSuffStat| { + ScaledPrior::new_unchecked( + self.parent.posterior_from_suffstat(&stat.parent()), + self.scale, + ) + }) } fn ln_m_cache(&self) -> Self::MCache { diff --git a/src/traits.rs b/src/traits.rs index 1d4506e1..16e7d4b1 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,4 +1,5 @@ //! Trait definitions +use crate::data::extract_stat_then; pub use crate::data::DataOrSuffStat; use rand::Rng; @@ -558,7 +559,9 @@ where self.posterior(&DataOrSuffStat::SuffStat(stat)) } - fn posterior(&self, x: &DataOrSuffStat) -> Self::Posterior; + fn posterior(&self, x: &DataOrSuffStat) -> Self::Posterior { + extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat)) + } /// Compute the cache for the log marginal likelihood. fn ln_m_cache(&self) -> Self::MCache; From f1264a0fb010fbb06299ee1a7b7d552077c5c34f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 10:54:04 -0700 Subject: [PATCH 08/17] more --- src/dist/scaled_prior.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 5a5ab086..32026bc3 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -166,16 +166,14 @@ where ScaledSuffStat::new(parent_stat, self.scale) } - fn posterior( + fn posterior_from_suffstat( &self, - x: &DataOrSuffStat>, + stat: &ScaledSuffStat, ) -> Self::Posterior { - extract_stat_then(self, x, |stat: &ScaledSuffStat| { - ScaledPrior::new_unchecked( - self.parent.posterior_from_suffstat(&stat.parent()), - self.scale, - ) - }) + ScaledPrior::new_unchecked( + self.parent.posterior_from_suffstat(&stat.parent()), + self.scale, + ) } fn ln_m_cache(&self) -> Self::MCache { From 937cd2adf3086ae4a9db0e846f2641cc4aceb400 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 11:04:05 -0700 Subject: [PATCH 09/17] extract_stat rewrites --- src/dist/normal_gamma/gaussian_prior.rs | 30 +++++++++---------- .../normal_inv_chi_squared/gaussian_prior.rs | 23 +++++++------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/dist/normal_gamma/gaussian_prior.rs b/src/dist/normal_gamma/gaussian_prior.rs index d5870baf..6dbaa2ea 100644 --- a/src/dist/normal_gamma/gaussian_prior.rs +++ b/src/dist/normal_gamma/gaussian_prior.rs @@ -3,7 +3,7 @@ use std::f64::consts::LN_2; use super::dos_to_post; use crate::consts::*; -use crate::data::{extract_stat, GaussianSuffStat}; +use crate::data::{extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalGamma}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -77,20 +77,20 @@ impl ConjugatePrior for NormalGamma { } fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { - let stat = extract_stat(self, x); - - let params = posterior_from_stat(self, &stat); - let PosteriorParameters { r, s, v, .. } = params; - - let half_v = v / 2.0; - let g_ratio = ln_gammafn(half_v + 0.5) - ln_gammafn(half_v); - let term = 0.5_f64.mul_add(LN_2, -HALF_LN_2PI) - + 0.5_f64.mul_add( - (r / (r + 1_f64)).ln(), - half_v.mul_add(s.ln(), g_ratio), - ); - - (params, term) + extract_stat_then(self, x, |stat| { + let params = posterior_from_stat(self, &stat); + let PosteriorParameters { r, s, v, .. } = params; + + let half_v = v / 2.0; + let g_ratio = ln_gammafn(half_v + 0.5) - ln_gammafn(half_v); + let term = 0.5_f64.mul_add(LN_2, -HALF_LN_2PI) + + 0.5_f64.mul_add( + (r / (r + 1_f64)).ln(), + half_v.mul_add(s.ln(), g_ratio), + ); + + (params, term) + }) } fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs index 39ac2cbc..d23f5f02 100644 --- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs +++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::f64::consts::PI; use crate::consts::HALF_LN_PI; -use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat}; +use crate::data::{ extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalInvChiSquared}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -94,16 +94,17 @@ impl ConjugatePrior for NormalInvChiSquared { } fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { - let stat = extract_stat(self, x); - let post = posterior_from_stat(self, &stat); - let kn = post.kn; - let vn = post.vn; - - let z = 0.5_f64.mul_add( - (kn / ((kn + 1.0) * PI * vn * post.s2n)).ln(), - ln_gammafn((vn + 1.0) / 2.0) - ln_gammafn(vn / 2.0), - ); - (post, z) + extract_stat_then(self, x, |stat: &GaussianSuffStat| { + let post = posterior_from_stat(self, &stat); + let kn = post.kn; + let vn = post.vn; + + let z = 0.5_f64.mul_add( + (kn / ((kn + 1.0) * PI * vn * post.s2n)).ln(), + ln_gammafn((vn + 1.0) / 2.0) - ln_gammafn(vn / 2.0), + ); + (post, z) + }) } fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { From 2c2bb83e3a63135f69cd918a60d0ca4088a6b878 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 11:04:16 -0700 Subject: [PATCH 10/17] drop unused import --- src/dist/scaled_prior.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 32026bc3..7d1af831 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -1,4 +1,3 @@ -use crate::data::extract_stat_then; use crate::data::{DataOrSuffStat, ScaledSuffStat}; use crate::dist::Scaled; use crate::traits::*; From 4a2cdcfbe52f83e9574f7022fa858e5a0f798014 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 21 Apr 2025 19:00:51 -0700 Subject: [PATCH 11/17] simplify --- .../normal_inv_chi_squared/gaussian_prior.rs | 11 ++++++----- .../stick_breaking_process/stick_breaking.rs | 16 ---------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs index d23f5f02..dfa58f07 100644 --- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs +++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::f64::consts::PI; use crate::consts::HALF_LN_PI; -use crate::data::{ extract_stat_then, GaussianSuffStat}; +use crate::data::{extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalInvChiSquared}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -69,10 +69,11 @@ impl ConjugatePrior for NormalInvChiSquared { GaussianSuffStat::new() } - fn posterior(&self, x: &DataOrSuffStat) -> Self { - extract_stat_then(self, x, |stat: &GaussianSuffStat| { - posterior_from_stat(self, &stat).into() - }) + fn posterior_from_suffstat( + &self, + stat: &GaussianSuffStat, + ) -> Self::Posterior { + posterior_from_stat(self, stat).into() } #[inline] diff --git a/src/experimental/stick_breaking_process/stick_breaking.rs b/src/experimental/stick_breaking_process/stick_breaking.rs index 731e2697..0d7725f7 100644 --- a/src/experimental/stick_breaking_process/stick_breaking.rs +++ b/src/experimental/stick_breaking_process/stick_breaking.rs @@ -288,22 +288,6 @@ impl ConjugatePrior for StickBreaking { } } - fn posterior( - &self, - x: &DataOrSuffStat, - ) -> Self::Posterior { - match x { - DataOrSuffStat::Data(xs) => { - let mut stat = StickBreakingDiscreteSuffStat::new(); - stat.observe_many(xs); - self.posterior_from_suffstat(&stat) - } - DataOrSuffStat::SuffStat(stat) => { - self.posterior_from_suffstat(stat) - } - } - } - /// Computes the logarithm of the marginal likelihood. fn ln_m(&self, x: &DataOrSuffStat) -> f64 { let count_pairs = match x { From cc5df2ccca089ab61e1db60bd5f5854e69ecc3cd Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 22 Apr 2025 07:11:21 -0700 Subject: [PATCH 12/17] progress --- src/data/mod.rs | 1 - src/dist/scaled_prior.rs | 17 ++++++----------- src/dist/shifted_prior.rs | 3 +-- src/dist/vonmises.rs | 2 +- src/test.rs | 27 +++++++++++++++++---------- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/data/mod.rs b/src/data/mod.rs index fc8757af..a9e738eb 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -578,7 +578,6 @@ mod tests { &self, _x: &crate::data::GaussianData, ) -> Self::PpCache { - () } fn ln_pp_with_cache( diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 7d1af831..4d9d0757 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -1,3 +1,4 @@ +use crate::data::extract_stat_then; use crate::data::{DataOrSuffStat, ScaledSuffStat}; use crate::dist::Scaled; use crate::traits::*; @@ -200,15 +201,10 @@ where &self, x: &DataOrSuffStat>, ) -> Self::PpCache { - // For now, we'll just compute from data - let data: Vec = match x { - DataOrSuffStat::Data(xs) => { - xs.iter().map(|&x| x * self.rate).collect() - } - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now - }; - - self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) + extract_stat_then(self, x, |stat| { + self.parent + .ln_pp_cache(&DataOrSuffStat::SuffStat(stat.parent())) + }) } fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { @@ -224,8 +220,7 @@ where mod tests { use super::*; use crate::data::DataOrSuffStat; - use crate::dist::{Gaussian, NormalInvChiSquared, Scaled}; - use crate::traits::*; + use crate::dist::{NormalInvChiSquared}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs index d575a912..f96bc815 100644 --- a/src/dist/shifted_prior.rs +++ b/src/dist/shifted_prior.rs @@ -207,8 +207,7 @@ where mod tests { use super::*; use crate::data::DataOrSuffStat; - use crate::dist::{Gaussian, NormalInvChiSquared, Shifted}; - use crate::traits::*; + use crate::dist::{NormalInvChiSquared}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; diff --git a/src/dist/vonmises.rs b/src/dist/vonmises.rs index 6ac6ded0..f3f80b0b 100644 --- a/src/dist/vonmises.rs +++ b/src/dist/vonmises.rs @@ -641,7 +641,7 @@ mod tests { #[test] fn slice_step_vs_draw_test() { - let n_samples = 1000000; + let n_samples = 1_000_000; let mut rng = rand::thread_rng(); let mu = 1.5; let k = 2.0; diff --git a/src/test.rs b/src/test.rs index 776ce440..8ef8a604 100644 --- a/src/test.rs +++ b/src/test.rs @@ -10,8 +10,8 @@ macro_rules! test_serde_params { ($fx: expr, $fx_ty: ty, $x_ty: ty) => { #[test] fn test_serde_ln_f() { - use ::serde::Deserialize; - use ::serde::Serialize; + // use ::serde::Deserialize; + // use ::serde::Serialize; use $crate::traits::HasDensity; use $crate::traits::Sampleable; @@ -636,17 +636,17 @@ where Ok(p_value) } -mod tests { - use crate::prelude::Exponential; - use crate::prelude::Gaussian; - use crate::test::density_histogram_test; - use crate::traits::HasDensity; - use crate::traits::Sampleable; - use rand::SeedableRng; - use rand_xoshiro::Xoshiro256Plus; +mod tests { #[test] fn test_density_histogram_gaussian() { + use crate::prelude::Gaussian; + use crate::test::density_histogram_test; + use crate::traits::HasDensity; + use crate::traits::Sampleable; + use rand_xoshiro::Xoshiro256Plus; + use rand::SeedableRng; + let mut rng = Xoshiro256Plus::seed_from_u64(1); let dist = Gaussian::default(); @@ -670,6 +670,13 @@ mod tests { #[test] fn test_density_histogram_exponential() { + use crate::prelude::Exponential; + use crate::traits::Sampleable; + use crate::traits::HasDensity; + use crate::test::density_histogram_test; + use rand_xoshiro::Xoshiro256Plus; + use rand::SeedableRng; + let mut rng = Xoshiro256Plus::seed_from_u64(1); let dist = Exponential::default(); From 02654d4d8156a7fdef080898d486787466f5b0fc Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 22 Apr 2025 07:16:12 -0700 Subject: [PATCH 13/17] shifted --- src/dist/shifted_prior.rs | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs index f96bc815..42611e74 100644 --- a/src/dist/shifted_prior.rs +++ b/src/dist/shifted_prior.rs @@ -1,4 +1,5 @@ use crate::data::{DataOrSuffStat, ShiftedSuffStat}; +use crate::data::extract_stat_then; use crate::dist::Shifted; use crate::traits::*; use rand::Rng; @@ -141,22 +142,21 @@ where ShiftedSuffStat::new(parent_stat, self.shift) } + fn posterior_from_suffstat( + &self, + stat: &ShiftedSuffStat, + ) -> Self::Posterior { + ShiftedPrior::new_unchecked( + self.parent.posterior_from_suffstat(&stat.parent()), + self.shift, + ) + } + fn posterior( &self, x: &DataOrSuffStat>, ) -> Self::Posterior { - // For now, we'll just compute a new posterior with the same parameters - // In the future, we should implement proper handling of the data - let data: Vec = match x { - DataOrSuffStat::Data(xs) => { - xs.iter().map(|&x| x - self.shift).collect() - } - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now - }; - - let posterior_parent = - self.parent.posterior(&DataOrSuffStat::Data(&data)); - Self::new_unchecked(posterior_parent, self.shift) + extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat)) } fn ln_m_cache(&self) -> Self::MCache { @@ -184,15 +184,10 @@ where &self, x: &DataOrSuffStat>, ) -> Self::PpCache { - // For now, we'll just compute from data - let data: Vec = match x { - DataOrSuffStat::Data(xs) => { - xs.iter().map(|&x| x - self.shift).collect() - } - DataOrSuffStat::SuffStat(_) => vec![], // Not handling suffstat for now - }; - - self.parent.ln_pp_cache(&DataOrSuffStat::Data(&data)) + extract_stat_then(self, x, |stat| { + self.parent + .ln_pp_cache(&DataOrSuffStat::SuffStat(stat.parent())) + }) } fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { From fd7b2addecc05ef66a1c45a98a236e5792fabe53 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 22 Apr 2025 07:19:19 -0700 Subject: [PATCH 14/17] cargo fmt --- src/dist/scaled_prior.rs | 2 +- src/dist/shifted_prior.rs | 6 +++--- src/test.rs | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 4d9d0757..082f8e97 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -220,7 +220,7 @@ where mod tests { use super::*; use crate::data::DataOrSuffStat; - use crate::dist::{NormalInvChiSquared}; + use crate::dist::NormalInvChiSquared; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs index 42611e74..75b3eefe 100644 --- a/src/dist/shifted_prior.rs +++ b/src/dist/shifted_prior.rs @@ -1,5 +1,5 @@ -use crate::data::{DataOrSuffStat, ShiftedSuffStat}; use crate::data::extract_stat_then; +use crate::data::{DataOrSuffStat, ShiftedSuffStat}; use crate::dist::Shifted; use crate::traits::*; use rand::Rng; @@ -151,7 +151,7 @@ where self.shift, ) } - + fn posterior( &self, x: &DataOrSuffStat>, @@ -202,7 +202,7 @@ where mod tests { use super::*; use crate::data::DataOrSuffStat; - use crate::dist::{NormalInvChiSquared}; + use crate::dist::NormalInvChiSquared; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; diff --git a/src/test.rs b/src/test.rs index 8ef8a604..9cd666e1 100644 --- a/src/test.rs +++ b/src/test.rs @@ -636,16 +636,16 @@ where Ok(p_value) } -mod tests { +mod tests { #[test] fn test_density_histogram_gaussian() { use crate::prelude::Gaussian; - use crate::test::density_histogram_test; + use crate::test::density_histogram_test; use crate::traits::HasDensity; use crate::traits::Sampleable; - use rand_xoshiro::Xoshiro256Plus; use rand::SeedableRng; + use rand_xoshiro::Xoshiro256Plus; let mut rng = Xoshiro256Plus::seed_from_u64(1); let dist = Gaussian::default(); @@ -671,11 +671,11 @@ mod tests { #[test] fn test_density_histogram_exponential() { use crate::prelude::Exponential; - use crate::traits::Sampleable; + use crate::test::density_histogram_test; use crate::traits::HasDensity; - use crate::test::density_histogram_test; - use rand_xoshiro::Xoshiro256Plus; + use crate::traits::Sampleable; use rand::SeedableRng; + use rand_xoshiro::Xoshiro256Plus; let mut rng = Xoshiro256Plus::seed_from_u64(1); let dist = Exponential::default(); From 9671b6a936140e4f569765c028a435bb295d3057 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 22 Apr 2025 07:21:04 -0700 Subject: [PATCH 15/17] drop old comments --- src/test.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/test.rs b/src/test.rs index 9cd666e1..463dcc59 100644 --- a/src/test.rs +++ b/src/test.rs @@ -10,8 +10,6 @@ macro_rules! test_serde_params { ($fx: expr, $fx_ty: ty, $x_ty: ty) => { #[test] fn test_serde_ln_f() { - // use ::serde::Deserialize; - // use ::serde::Serialize; use $crate::traits::HasDensity; use $crate::traits::Sampleable; From 1b9458b9963a6f275df960baead11710008628a4 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Tue, 22 Apr 2025 09:35:10 -0600 Subject: [PATCH 16/17] style: Adjusted style for clippy. --- benches/vonmises.rs | 3 --- examples/betaprime_sbc.rs | 3 --- examples/sbd.rs | 3 --- examples/stickbreaking_posterior.rs | 4 ---- src/data/mod.rs | 4 ++-- src/dist/betaprime.rs | 7 +++---- src/dist/cdvm.rs | 11 +++++++---- src/dist/normal_gamma/gaussian_prior.rs | 2 +- .../normal_inv_chi_squared/gaussian_prior.rs | 4 ++-- src/dist/scaled_prior.rs | 2 +- src/dist/shifted_prior.rs | 2 +- src/dist/vonmises.rs | 16 +++++++++------- src/misc/bessel.rs | 7 +++++-- src/misc/func.rs | 1 + src/test.rs | 5 ++--- 15 files changed, 34 insertions(+), 40 deletions(-) diff --git a/benches/vonmises.rs b/benches/vonmises.rs index e40d11da..bd56d1dc 100644 --- a/benches/vonmises.rs +++ b/benches/vonmises.rs @@ -1,17 +1,14 @@ use criterion::black_box; -use criterion::measurement::WallTime; use criterion::AxisScale; use criterion::BatchSize; use criterion::BenchmarkId; use criterion::Criterion; use criterion::PlotConfiguration; -use criterion::Throughput; use criterion::{criterion_group, criterion_main}; use rand::Rng; use rv::dist::VonMises; use rv::misc::bessel::log_i0; use rv::prelude::*; -use rv::traits::*; use std::f64::consts::PI; fn bench_vm_draw(c: &mut Criterion) { diff --git a/examples/betaprime_sbc.rs b/examples/betaprime_sbc.rs index b7df6ec1..fda5a39b 100644 --- a/examples/betaprime_sbc.rs +++ b/examples/betaprime_sbc.rs @@ -1,5 +1,3 @@ -use rv::dist::BetaPrime; - #[cfg(feature = "experimental")] use rand::SeedableRng; #[cfg(feature = "experimental")] @@ -8,7 +6,6 @@ use rand_xoshiro::Xoshiro256Plus; use rv::experimental::stick_breaking_process::{ StickBreaking, StickBreakingDiscrete, StickBreakingDiscreteSuffStat, }; -use rv::prelude::*; // Simulation-based calibration // For details see http://www.stat.columbia.edu/~gelman/research/unpublished/sbc.pdf diff --git a/examples/sbd.rs b/examples/sbd.rs index 224018cf..1161afe3 100644 --- a/examples/sbd.rs +++ b/examples/sbd.rs @@ -1,6 +1,3 @@ -use rand::SeedableRng; -use rv::prelude::*; - #[cfg(feature = "experimental")] use rv::experimental::stick_breaking_process::{ StickBreaking, StickBreakingDiscrete, StickSequence, diff --git a/examples/stickbreaking_posterior.rs b/examples/stickbreaking_posterior.rs index d72eb436..51dc3ebb 100644 --- a/examples/stickbreaking_posterior.rs +++ b/examples/stickbreaking_posterior.rs @@ -1,7 +1,3 @@ -use itertools::Either; -use peroxide::statistics::stat::Statistics; -use rv::prelude::*; - #[cfg(feature = "experimental")] use rv::experimental::stick_breaking_process::*; diff --git a/src/data/mod.rs b/src/data/mod.rs index a9e738eb..9d2b1015 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -386,8 +386,8 @@ mod tests { fn impl_bool_into_bool() { let t = true; let f = false; - assert_eq!(t.into_bool(), true); - assert_eq!(f.into_bool(), false); + assert!(t.into_bool()); + assert!(!f.into_bool()); } #[test] diff --git a/src/dist/betaprime.rs b/src/dist/betaprime.rs index 04307a9f..be2ba15a 100644 --- a/src/dist/betaprime.rs +++ b/src/dist/betaprime.rs @@ -10,6 +10,9 @@ use std::f64; use std::fmt; use std::sync::OnceLock; +#[cfg(feature = "experimental")] +use super::UnitPowerLaw; + /// [Beta prime distribution](https://en.wikipedia.org/wiki/Beta_prime_distribution), /// BetaPrime(α, β) over x in (0, ∞). /// @@ -277,18 +280,14 @@ impl Sampleable for BetaPrime { } } -use crate::data::DataOrSuffStat; #[cfg(feature = "experimental")] use crate::experimental::stick_breaking_process::{ StickBreakingDiscrete, StickBreakingDiscreteSuffStat, }; -use crate::traits::ConjugatePrior; #[cfg(feature = "experimental")] use crate::experimental::stick_breaking_process::StickBreaking; -use crate::prelude::UnitPowerLaw; - #[cfg(feature = "experimental")] impl Sampleable for BetaPrime { fn draw(&self, rng: &mut R) -> StickBreakingDiscrete { diff --git a/src/dist/cdvm.rs b/src/dist/cdvm.rs index d3fab0df..73632ea6 100644 --- a/src/dist/cdvm.rs +++ b/src/dist/cdvm.rs @@ -243,10 +243,13 @@ impl HasSuffStat for Cdvm { // TODO: Should we cache twopimu_over_m.cos() and twopimu_over_m.sin()? let (sin_twopimu_over_m, cos_twopimu_over_m) = twopimu_over_m.sin_cos(); - self.kappa - * (stat.sum_cos() * cos_twopimu_over_m - + stat.sum_sin() * sin_twopimu_over_m) - - stat.n() as f64 * self.log_norm_const() + self.kappa.mul_add( + stat.sum_cos().mul_add( + cos_twopimu_over_m, + stat.sum_sin() * sin_twopimu_over_m, + ), + -(stat.n() as f64 * self.log_norm_const()), + ) } } diff --git a/src/dist/normal_gamma/gaussian_prior.rs b/src/dist/normal_gamma/gaussian_prior.rs index 6dbaa2ea..4060c918 100644 --- a/src/dist/normal_gamma/gaussian_prior.rs +++ b/src/dist/normal_gamma/gaussian_prior.rs @@ -78,7 +78,7 @@ impl ConjugatePrior for NormalGamma { fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { extract_stat_then(self, x, |stat| { - let params = posterior_from_stat(self, &stat); + let params = posterior_from_stat(self, stat); let PosteriorParameters { r, s, v, .. } = params; let half_v = v / 2.0; diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs index dfa58f07..3dce9924 100644 --- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs +++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs @@ -88,7 +88,7 @@ impl ConjugatePrior for NormalInvChiSquared { ) -> f64 { extract_stat_then(self, x, |stat: &GaussianSuffStat| { let n = stat.n() as f64; - let post: Self = posterior_from_stat(self, &stat).into(); + let post: Self = posterior_from_stat(self, stat).into(); let lnz_n = post.ln_z(); n.mul_add(-HALF_LN_PI, lnz_n - cache) }) @@ -96,7 +96,7 @@ impl ConjugatePrior for NormalInvChiSquared { fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { extract_stat_then(self, x, |stat: &GaussianSuffStat| { - let post = posterior_from_stat(self, &stat); + let post = posterior_from_stat(self, stat); let kn = post.kn; let vn = post.vn; diff --git a/src/dist/scaled_prior.rs b/src/dist/scaled_prior.rs index 082f8e97..7b5c4a86 100644 --- a/src/dist/scaled_prior.rs +++ b/src/dist/scaled_prior.rs @@ -171,7 +171,7 @@ where stat: &ScaledSuffStat, ) -> Self::Posterior { ScaledPrior::new_unchecked( - self.parent.posterior_from_suffstat(&stat.parent()), + self.parent.posterior_from_suffstat(stat.parent()), self.scale, ) } diff --git a/src/dist/shifted_prior.rs b/src/dist/shifted_prior.rs index 75b3eefe..8b9ebd3b 100644 --- a/src/dist/shifted_prior.rs +++ b/src/dist/shifted_prior.rs @@ -147,7 +147,7 @@ where stat: &ShiftedSuffStat, ) -> Self::Posterior { ShiftedPrior::new_unchecked( - self.parent.posterior_from_suffstat(&stat.parent()), + self.parent.posterior_from_suffstat(stat.parent()), self.shift, ) } diff --git a/src/dist/vonmises.rs b/src/dist/vonmises.rs index f3f80b0b..46d325f3 100644 --- a/src/dist/vonmises.rs +++ b/src/dist/vonmises.rs @@ -312,13 +312,13 @@ impl VonMises { #[inline] pub fn slice_step(x: f64, mu: f64, k: f64, rng: &mut R) -> f64 { // y ~ Uniform(0, exp(k * cos(x - μ))) - let logy = rng.gen::().ln() + k * (x - mu).cos(); + let logy = k.mul_add((x - mu).cos(), rng.gen::().ln()); // Need to solve for x in k cos(x) = logy // If logy < -k, then we're below the cos curve // In that case, sample uniformly on the circle let xmax = if logy < -k { PI } else { (logy / k).acos() }; // Sample uniformly on [-xmax, xmax] and add μ - let x = xmax * rng.gen_range(-1.0..=1.0) + mu; + let x = xmax.mul_add(rng.gen_range(-1.0..=1.0), mu); // Ensure result is in [0, 2π) x.rem_euclid(2.0 * PI) } @@ -378,7 +378,7 @@ macro_rules! impl_traits { } } let z = (1.0 - t) / (1.0 + t); - f = (1.0 + r * z) / (r + z); + f = r.mul_add(z, 1.0) / (r + z); let c = self.k * (r - f); if (c.mul_add(2.0 - c, -u) > 0.0) || ((c / u).ln() + 1.0 - c >= 0.0) @@ -462,9 +462,11 @@ impl HasSuffStat for VonMises { } fn ln_f_stat(&self, stat: &Self::Stat) -> f64 { - self.k - * (stat.sum_cos() * self.cos_mu() + stat.sum_sin() * self.sin_mu()) - - stat.n() as f64 * (self.log_i0_k() + LN_2PI) + self.k.mul_add( + stat.sum_cos() + .mul_add(self.cos_mu(), stat.sum_sin() * self.sin_mu()), + -(stat.n() as f64 * (self.log_i0_k() + LN_2PI)), + ) } } @@ -734,7 +736,7 @@ mod tests { let sample: f64 = vm.draw(&mut rng); prop_assert!( - sample >= 0.0 && sample < 2.0 * std::f64::consts::PI, + (0.0..2.0 * std::f64::consts::PI).contains(&sample), "Sample {} not in range [0, 2π) for VonMises({}, {})", sample, mu, k ); diff --git a/src/misc/bessel.rs b/src/misc/bessel.rs index 26e4c0ae..b9651fa1 100644 --- a/src/misc/bessel.rs +++ b/src/misc/bessel.rs @@ -158,8 +158,11 @@ pub fn log_i0(x: f64) -> f64 { let y = ax.mul_add(0.5, -2.0); ax + chbevl(y, &BESSI0_COEFFS_A).ln() } else { - ax + chbevl(32.0_f64.mul_add(ax.recip(), -2.0), &BESSI0_COEFFS_B).ln() - - 0.5 * ax.ln() + 0.5_f64.mul_add( + -ax.ln(), + ax + chbevl(32.0_f64.mul_add(ax.recip(), -2.0), &BESSI0_COEFFS_B) + .ln(), + ) } } diff --git a/src/misc/func.rs b/src/misc/func.rs index 2a9e7453..cb80f385 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -469,6 +469,7 @@ pub fn sorted_uniforms(n: usize, rng: &mut R) -> Vec { xs } +#[allow(dead_code)] pub(crate) fn eq_or_close(a: f64, b: f64, tol: f64) -> bool { a == b // Really equal, or both -Inf or Inf || a.is_nan() && b.is_nan() // Both NaN diff --git a/src/test.rs b/src/test.rs index 463dcc59..712fac04 100644 --- a/src/test.rs +++ b/src/test.rs @@ -581,14 +581,13 @@ where let mut total_integral = 0.0; for bin_ix in 0..num_bins { - let bin_start = min_val + bin_ix as f64 * bin_width; + let bin_start = (bin_ix as f64).mul_add(bin_width, min_val); let bin_mid = bin_start + bin_width / 2.0; let bin_end = bin_start + bin_width; // Apply Simpson's rule for each bin let integral = (bin_width / 6.0) - * (density_fn(bin_start) - + 4.0 * density_fn(bin_mid) + * (4.0_f64.mul_add(density_fn(bin_mid), density_fn(bin_start)) + density_fn(bin_end)); expected_counts.push(integral); From aaca36a47133c9fa59a4ffb07ae2566ae748afd0 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Tue, 22 Apr 2025 10:20:31 -0600 Subject: [PATCH 17/17] Fixed clippy --fix mistakes --- examples/betaprime_sbc.rs | 16 +++++++--------- examples/sbd.rs | 11 ++++++----- examples/stickbreaking_posterior.rs | 8 +++++--- src/dist/betaprime.rs | 1 + 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/examples/betaprime_sbc.rs b/examples/betaprime_sbc.rs index fda5a39b..aa32564e 100644 --- a/examples/betaprime_sbc.rs +++ b/examples/betaprime_sbc.rs @@ -1,16 +1,14 @@ -#[cfg(feature = "experimental")] -use rand::SeedableRng; -#[cfg(feature = "experimental")] -use rand_xoshiro::Xoshiro256Plus; -#[cfg(feature = "experimental")] -use rv::experimental::stick_breaking_process::{ - StickBreaking, StickBreakingDiscrete, StickBreakingDiscreteSuffStat, -}; - // Simulation-based calibration // For details see http://www.stat.columbia.edu/~gelman/research/unpublished/sbc.pdf #[cfg(feature = "experimental")] fn main() { + use rand::SeedableRng; + use rand_xoshiro::Xoshiro256Plus; + use rv::experimental::stick_breaking_process::{ + StickBreaking, StickBreakingDiscrete, StickBreakingDiscreteSuffStat, + }; + use rv::prelude::*; + let mut rng = Xoshiro256Plus::seed_from_u64(123); let n_samples = 10000; let n_obs = 10; diff --git a/examples/sbd.rs b/examples/sbd.rs index 1161afe3..df7778c1 100644 --- a/examples/sbd.rs +++ b/examples/sbd.rs @@ -1,11 +1,12 @@ -#[cfg(feature = "experimental")] -use rv::experimental::stick_breaking_process::{ - StickBreaking, StickBreakingDiscrete, StickSequence, -}; - fn main() { #[cfg(feature = "experimental")] { + use rand_xoshiro::rand_core::SeedableRng; + use rv::experimental::stick_breaking_process::{ + StickBreaking, StickBreakingDiscrete, StickSequence, + }; + use rv::prelude::*; + // Instantiate a stick-breaking process let alpha = 10.0; let sbp = StickBreaking::new(UnitPowerLaw::new(alpha).unwrap()); diff --git a/examples/stickbreaking_posterior.rs b/examples/stickbreaking_posterior.rs index 51dc3ebb..b8a45ce9 100644 --- a/examples/stickbreaking_posterior.rs +++ b/examples/stickbreaking_posterior.rs @@ -1,9 +1,11 @@ -#[cfg(feature = "experimental")] -use rv::experimental::stick_breaking_process::*; - fn main() { #[cfg(feature = "experimental")] { + use itertools::Either; + use peroxide::fuga::Statistics; + use rv::experimental::stick_breaking_process::*; + use rv::prelude::*; + let mut rng = rand::thread_rng(); let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); diff --git a/src/dist/betaprime.rs b/src/dist/betaprime.rs index be2ba15a..25f6a6c1 100644 --- a/src/dist/betaprime.rs +++ b/src/dist/betaprime.rs @@ -389,6 +389,7 @@ impl ConjugatePrior for BetaPrime { ) -> Self { match data { DataOrSuffStat::Data(xs) => { + #[allow(clippy::useless_asref)] let stat = StickBreakingDiscreteSuffStat::from(xs.as_ref()); self.posterior(&DataOrSuffStat::SuffStat(&stat)) }