Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rv"
version = "0.18.0"
version = "0.18.1"
authors = ["Baxter Eaves", "Michael Schmidt", "Chad Scherrer"]
description = "Random variables"
repository = "https://github.com/promised-ai/rv"
Expand Down
26 changes: 25 additions & 1 deletion src/misc/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ where
let (alpha, r) =
self.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| {
let x = *x.borrow();
if x <= alpha {

if x == f64::NEG_INFINITY {
return (alpha, r);
} else if x <= alpha {
(alpha, r + (x - alpha).exp())
} else {
(x, (alpha - x).exp().mul_add(r, 1.0))
Expand Down Expand Up @@ -864,6 +867,27 @@ mod tests {
assert_eq!(argmax(&xs), vec![4, 6]);
}

#[test]
fn logsumexp_nan_handling() {
let a: f64 = -3.0;
let b: f64 = -7.0;
let target: f64 = logaddexp(a, b);
let xs = [
-f64::INFINITY,
a,
-f64::INFINITY,
b,
-f64::INFINITY,
-f64::INFINITY,
-f64::INFINITY,
-f64::INFINITY,
-f64::INFINITY,
-f64::INFINITY,
];
let result = xs.iter().logsumexp();
assert!((result - target).abs() < 1e-12);
}

proptest! {
#[test]
fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) {
Expand Down
Loading