
Conversation

@juanitorduz (Collaborator):

Closes #2088

@juanitorduz requested a review from fehiepsi on October 27, 2025.
@juanitorduz added the bug (Something isn't working) and enhancement (New feature or request) labels on Oct 27, 2025.
Review comment thread on the diff (excerpt, truncated):

```python
return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1))
# Handle edge cases where concentration1=1 and value=0, or concentration0=1 and value=1
# These cases would result in nan due to log(0) * 0 in the Dirichlet computation
log_prob = jnp.where(
```
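
For context (my addition, not part of the thread), a minimal reproduction of the linked issue, assuming a NumPyro version without this fix and the standard `numpyro.distributions.Beta` constructor:

```python
import numpyro.distributions as dist

# Beta(concentration1=1, concentration0=2) has density 2 at value=0, so the
# expected log_prob is log(2) ~ 0.693; before this fix it came back as nan.
print(dist.Beta(concentration1=1.0, concentration0=2.0).log_prob(0.0))
```
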
@fehiepsi (Member):

interesting! could you check if grads w.r.t. value, concentration1, concentration0 are not NaN?
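
For illustration (my sketch, not code from the PR), one way to run the check the reviewer is asking for, taking gradients of `Beta.log_prob` with respect to all three arguments at the problematic point:

```python
import jax
import numpyro.distributions as dist

def beta_log_prob(value, concentration1, concentration0):
    return dist.Beta(concentration1, concentration0).log_prob(value)

# Gradients w.r.t. value, concentration1, and concentration0 at the edge case
# value=0, concentration1=1; any NaN means the masking only fixed the forward pass.
print(jax.grad(beta_log_prob, argnums=(0, 1, 2))(0.0, 1.0, 2.0))
```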

@juanitorduz (Collaborator, author) replied on Nov 1, 2025:

Good catch 🙈! The gradients are NaN with this approach. After some investigation and talking with Claude Code 😅, we arrived at a solution via https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html; see 13fc288 and 19f8be7
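
For readers unfamiliar with the linked API, a generic `jax.custom_jvp` sketch of the idea (my example, not the actual code in 13fc288 or 19f8be7; the `xlogy` helper here is purely illustrative): override the derivative of a boundary-sensitive function so the NaN never enters the tangent.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def xlogy(x, y):
    # xlogy(0, y) is defined as 0, but the naive gradient at x=0, y=0 is NaN
    # because the rejected branch evaluates log(0).
    return jnp.where(x == 0.0, 0.0, x * jnp.log(y))

@xlogy.defjvp
def xlogy_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = xlogy(x, y)
    # d/dx = log(y), d/dy = x / y; force both to 0 at the x == 0, y == 0 corner
    # so the boundary no longer poisons the gradient with NaN.
    safe_y = jnp.where((x == 0.0) & (y == 0.0), 1.0, y)
    tangent_out = x_dot * jnp.log(safe_y) + y_dot * x / safe_y
    return primal_out, tangent_out
```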

@fehiepsi (Member):

It is better to mask the extreme values with canonical ones before calling the Dirichlet log_prob. Using where is better than a custom jvp, I think: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
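
For reference, the linked note describes the "double where" trick: substitute a safe input inside the rejected branch so it never produces inf or NaN, then select the intended output. A small generic sketch of it (my example, not the PR code):

```python
import jax
import jax.numpy as jnp

def naive_sqrt(x):
    # Forward pass is fine, but the gradient at 0 is NaN: the rejected sqrt
    # branch contributes 0 * inf in the backward pass.
    return jnp.where(x == 0.0, 0.0, jnp.sqrt(x))

def safe_sqrt(x):
    # Double-where: mask the extreme input with a canonical value first, so the
    # rejected branch never produces inf/NaN, then select the intended output.
    safe_x = jnp.where(x == 0.0, 1.0, x)
    return jnp.where(x == 0.0, 0.0, jnp.sqrt(safe_x))

print(jax.grad(naive_sqrt)(0.0))  # nan
print(jax.grad(safe_sqrt)(0.0))   # 0.0
```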

@juanitorduz (Collaborator, author):

ok! I tried to use this approach in 1a9eae1

@juanitorduz requested a review from fehiepsi on November 1, 2025.
Review comment thread on the diff (excerpt, truncated):

```python
# Step 4: Compute correct forward-pass value at boundaries
# Use stop_gradient to prevent gradients from flowing through this branch
# xlogy(0, 0) = 0 gives the correct value when concentration=1 at boundaries
boundary_log_prob = jax.lax.stop_gradient(
```
@fehiepsi (Member):

how about setting this to 0.0 instead of using stop_gradient?

@juanitorduz (Collaborator, author):

I could not make it work without stop_gradient :(

Review comment thread on the diff (excerpt, truncated):

```python
safe_value = jnp.where(is_boundary, 0.5, value)

# Step 3: Compute log_prob with safe values (gradients flow through here)
safe_log_prob = (
```
@fehiepsi (Member):

could we use self.dirichlet.log_prob here?

@juanitorduz (Collaborator, author):

tried it in c6f113b
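
Putting the thread together, the fix looks roughly like this sketch (my reading of the discussion, not the literal code in c6f113b): mask boundary values before the Dirichlet log_prob call, compute the exact boundary value separately under stop_gradient, and select with where.

```python
import jax
import jax.numpy as jnp
from numpyro.distributions import Dirichlet

def beta_log_prob(value, concentration1, concentration0):
    # Boundary cases that would otherwise hit log(0) * 0 inside the Dirichlet log_prob.
    is_boundary = ((concentration1 == 1.0) & (value == 0.0)) | (
        (concentration0 == 1.0) & (value == 1.0)
    )
    # Double-where trick: feed a canonical interior point to the Dirichlet so the
    # rejected branch stays NaN-free in both the forward and backward passes.
    safe_value = jnp.where(is_boundary, 0.5, value)
    dirichlet = Dirichlet(jnp.stack([concentration1, concentration0], -1))
    safe_log_prob = dirichlet.log_prob(jnp.stack([safe_value, 1.0 - safe_value], -1))
    # Exact density at the boundary: Beta(1, b) has density b at 0 and Beta(a, 1)
    # has density a at 1; stop_gradient keeps this branch out of the backward pass.
    boundary_log_prob = jax.lax.stop_gradient(
        jnp.log(jnp.where(value == 0.0, concentration0, concentration1))
    )
    return jnp.where(is_boundary, boundary_log_prob, safe_log_prob)

print(beta_log_prob(0.0, 1.0, 2.0))                            # ~log(2), not nan
print(jax.grad(beta_log_prob, argnums=(0, 1, 2))(0.0, 1.0, 2.0))  # finite, not NaN
```

In this sketch the boundary branch contributes zero gradients rather than NaN, which is what the stop_gradient buys.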

@juanitorduz requested a review from fehiepsi on November 19, 2025.

Labels

bug (Something isn't working), enhancement (New feature or request)


Development

Successfully merging this pull request may close these issues.

Beta with concentration1=1 gives nan log_prob at value=0

3 participants