Fix Beta with concentration1=1 gives nan log_prob at value=0 #2089
base: master
Conversation
numpyro/distributions/continuous.py
Outdated
```python
return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1))
# Handle edge cases where concentration1=1 and value=0, or concentration0=1 and value=1
# These cases would result in nan due to log(0) * 0 in the Dirichlet computation
log_prob = jnp.where(
```
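For context, the failure mode being fixed (a minimal reproduction against numpyro's Beta; before this change the `(concentration1 - 1) * log(value)` term inside the Dirichlet log-density evaluates to `0 * (-inf)`):

```python
import jax.numpy as jnp
from numpyro.distributions import Beta

# concentration1 = 1 at value = 0 (and, symmetrically, concentration0 = 1 at value = 1)
# hits log(0) * 0 inside the underlying Dirichlet log-density.
print(Beta(1.0, 1.0).log_prob(jnp.array(0.0)))  # nan before this fix; 0.0 (uniform density) afterwards
```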
interesting! could you check if grads w.r.t. value, concentration1, concentration0 are not NaN?
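A sketch of the kind of check being asked for (a standalone snippet, not part of the PR; the point value=0, concentration1=1 is the one from #2088):

```python
import jax
from numpyro.distributions import Beta

def log_prob(value, concentration1, concentration0):
    return Beta(concentration1, concentration0).log_prob(value)

# Gradients w.r.t. value, concentration1 and concentration0 at the problematic point.
grads = jax.grad(log_prob, argnums=(0, 1, 2))(0.0, 1.0, 1.0)
print(grads)  # each entry should be finite, not nan
```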
Good catch 🙈! The gradients are NaN with this approach. After some investigation (and a chat with Claude Code 😅), we arrived at a solution via https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html, see 13fc288 and 19f8be7
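Roughly, the idea behind the `custom_jvp` approach is to give the problematic `x * log(y)` term a hand-written JVP so that neither the forward value nor the tangent ever materializes a NaN (a simplified sketch, not the exact code in 13fc288/19f8be7; it only handles the `x = 0, y = 0` corner that the Beta boundary actually hits):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def xlogy(x, y):
    return jnp.where(x == 0.0, 0.0, x * jnp.log(y))  # convention: 0 * log(0) = 0

@xlogy.defjvp
def xlogy_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    # Mask the inputs of the unsafe expressions so the tangent is finite at x = 0, y = 0.
    safe_y = jnp.where(y == 0.0, 1.0, y)
    dx = jnp.where(x == 0.0, 0.0, jnp.log(safe_y))  # d/dx of x*log(y) is log(y)
    dy = x / safe_y                                  # d/dy of x*log(y) is x / y
    return xlogy(x, y), x_dot * dx + y_dot * dy

print(jax.grad(xlogy, argnums=(0, 1))(0.0, 0.0))  # (0.0, 0.0) instead of NaN
```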
It is better to mask the extreme values with canonical ones before calling the Dirichlet log_prob. Using where is better than a custom JVP, I think: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
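For reference, the pattern from that note (the "double where": mask the input of the unsafe branch, then select the result) in a tiny self-contained example:

```python
import jax
import jax.numpy as jnp

def safe_log(x):
    bad = x <= 0.0
    safe_x = jnp.where(bad, 1.0, x)                    # inner where: canonical input
    return jnp.where(bad, -jnp.inf, jnp.log(safe_x))   # outer where: select the result

# The naive jnp.where(x <= 0, -inf, jnp.log(x)) has a NaN gradient at 0.0, because the
# unselected branch still contributes 0 * inf in the backward pass; the double where
# keeps it finite.
print(jax.grad(safe_log)(0.0))  # 0.0, not nan
```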
ok! I tried to use this approach in 1a9eae1
numpyro/distributions/continuous.py
Outdated
```python
# Step 4: Compute correct forward-pass value at boundaries
# Use stop_gradient to prevent gradients from flowing through this branch
# xlogy(0, 0) = 0 gives the correct value when concentration=1 at boundaries
boundary_log_prob = jax.lax.stop_gradient(
```
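Read in isolation, the hunk's idea appears to be roughly the following (a hedged sketch; the actual expression is truncated above, and `xlogy`/`xlog1py`/`betaln` are my stand-ins for whatever the commit uses):

```python
import jax
from jax.scipy.special import betaln, xlog1py, xlogy

def boundary_log_prob_sketch(value, concentration1, concentration0):
    # xlogy(0, 0) = 0, so the forward value is finite when concentration1 == 1 and
    # value == 0 (and symmetrically for concentration0 == 1 and value == 1);
    # stop_gradient keeps the ill-defined boundary derivative out of the backward pass.
    return jax.lax.stop_gradient(
        xlogy(concentration1 - 1.0, value)
        + xlog1py(concentration0 - 1.0, -value)
        - betaln(concentration1, concentration0)
    )
```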
how about setting this to 0.0 instead of using stop_gradient?
I could not make it work without stop_gradient :(
numpyro/distributions/continuous.py
Outdated
```python
safe_value = jnp.where(is_boundary, 0.5, value)

# Step 3: Compute log_prob with safe values (gradients flow through here)
safe_log_prob = (
```
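Putting the visible pieces together, the overall shape seems to be something like this (a sketch only, not the exact code in 1a9eae1; numpyro's `Dirichlet` stands in for `self._dirichlet`):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betaln, xlog1py, xlogy
from numpyro.distributions import Dirichlet

def beta_log_prob_sketch(value, concentration1, concentration0):
    # Boundary cases that would otherwise hit log(0) * 0 in the Dirichlet log-density.
    is_boundary = ((concentration1 == 1.0) & (value == 0.0)) | (
        (concentration0 == 1.0) & (value == 1.0)
    )
    # Give the generic path a canonical interior point so neither its value nor its
    # gradients ever see log(0).
    safe_value = jnp.where(is_boundary, 0.5, value)
    dirichlet = Dirichlet(jnp.stack([concentration1, concentration0], -1))
    safe_log_prob = dirichlet.log_prob(jnp.stack([safe_value, 1.0 - safe_value], -1))
    # Finite forward value at the boundary (xlogy(0, 0) = 0), gradients stopped there.
    boundary_log_prob = jax.lax.stop_gradient(
        xlogy(concentration1 - 1.0, value)
        + xlog1py(concentration0 - 1.0, -value)
        - betaln(concentration1, concentration0)
    )
    return jnp.where(is_boundary, boundary_log_prob, safe_log_prob)

# Both the value and the gradients stay finite at the boundary:
print(beta_log_prob_sketch(0.0, 1.0, 2.0))                               # log(2)
print(jax.grad(beta_log_prob_sketch, argnums=(0, 1, 2))(0.0, 1.0, 2.0))  # no nan
```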
could we use self.dirichlet.log_prob here?
tried it in c6f113b
Closes #2088